mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-11 02:35:42 +08:00
Merge branch 'main' into wye-refactor-quant-folder
This commit is contained in:
commit
e925187f6d
2
.github/workflows/bc-lint.yml
vendored
2
.github/workflows/bc-lint.yml
vendored
@ -6,6 +6,8 @@ on:
|
|||||||
- opened
|
- opened
|
||||||
- synchronize
|
- synchronize
|
||||||
- reopened
|
- reopened
|
||||||
|
- labeled
|
||||||
|
- unlabeled
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
bc_lint:
|
bc_lint:
|
||||||
|
|||||||
@ -81,7 +81,7 @@ vLLM is flexible and easy to use with:
|
|||||||
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
||||||
- Streaming outputs
|
- Streaming outputs
|
||||||
- OpenAI-compatible API server
|
- OpenAI-compatible API server
|
||||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron
|
- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
|
||||||
- Prefix caching support
|
- Prefix caching support
|
||||||
- Multi-LoRA support
|
- Multi-LoRA support
|
||||||
|
|
||||||
|
|||||||
@ -1,77 +1,675 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import time
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||||
silu_mul_fp8_quant_deep_gemm,
|
silu_mul_fp8_quant_deep_gemm_cuda,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||||
|
|
||||||
|
|
||||||
def benchmark(E, T, H, G=128, runs=50):
|
@triton.jit
|
||||||
current_platform.seed_everything(42)
|
def _silu_mul_fp8_quant_deep_gemm(
|
||||||
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
|
# Pointers ------------------------------------------------------------
|
||||||
tokens_per_expert = torch.randint(
|
input_ptr, # 16-bit activations (E, T, 2*H)
|
||||||
T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
|
y_q_ptr, # fp8 quantized activations (E, T, H)
|
||||||
|
y_s_ptr, # 16-bit scales (E, T, G)
|
||||||
|
counts_ptr, # int32 num tokens per expert (E)
|
||||||
|
# Sizes ---------------------------------------------------------------
|
||||||
|
H: tl.constexpr, # hidden dimension (per output)
|
||||||
|
GROUP_SIZE: tl.constexpr, # elements per group (usually 128)
|
||||||
|
# Strides for input (elements) ---------------------------------------
|
||||||
|
stride_i_e,
|
||||||
|
stride_i_t,
|
||||||
|
stride_i_h,
|
||||||
|
# Strides for y_q (elements) -----------------------------------------
|
||||||
|
stride_yq_e,
|
||||||
|
stride_yq_t,
|
||||||
|
stride_yq_h,
|
||||||
|
# Strides for y_s (elements) -----------------------------------------
|
||||||
|
stride_ys_e,
|
||||||
|
stride_ys_t,
|
||||||
|
stride_ys_g,
|
||||||
|
# Stride for counts (elements)
|
||||||
|
stride_counts_e,
|
||||||
|
# Numeric params ------------------------------------------------------
|
||||||
|
eps: tl.constexpr,
|
||||||
|
fp8_min: tl.constexpr,
|
||||||
|
fp8_max: tl.constexpr,
|
||||||
|
use_ue8m0: tl.constexpr,
|
||||||
|
# Meta ---------------------------------------------------------------
|
||||||
|
BLOCK: tl.constexpr,
|
||||||
|
NUM_STAGES: tl.constexpr,
|
||||||
|
):
|
||||||
|
G = H // GROUP_SIZE
|
||||||
|
|
||||||
|
# map program id -> (e, g)
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
e = pid // G
|
||||||
|
g = pid % G
|
||||||
|
|
||||||
|
e = e.to(tl.int64)
|
||||||
|
g = g.to(tl.int64)
|
||||||
|
|
||||||
|
# number of valid tokens for this expert
|
||||||
|
n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
|
||||||
|
|
||||||
|
cols = tl.arange(0, BLOCK).to(tl.int64)
|
||||||
|
mask = cols < BLOCK
|
||||||
|
|
||||||
|
base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
|
||||||
|
base_gate_offset = base_input_offset + cols * stride_i_h
|
||||||
|
base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
|
||||||
|
base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h
|
||||||
|
base_ys_offset = e * stride_ys_e + g * stride_ys_g
|
||||||
|
|
||||||
|
for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
|
||||||
|
gate = tl.load(
|
||||||
|
input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0
|
||||||
|
).to(tl.float32)
|
||||||
|
up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0)
|
||||||
|
|
||||||
|
gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
|
||||||
|
y = gate * up
|
||||||
|
|
||||||
|
y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
|
||||||
|
if use_ue8m0:
|
||||||
|
y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
|
||||||
|
|
||||||
|
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||||
|
|
||||||
|
tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
|
||||||
|
tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
|
||||||
|
|
||||||
|
|
||||||
|
def silu_mul_fp8_quant_deep_gemm_triton(
|
||||||
|
y: torch.Tensor, # (E, T, 2*H)
|
||||||
|
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
|
||||||
|
num_parallel_tokens,
|
||||||
|
group_size: int = 128,
|
||||||
|
eps: float = 1e-10,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
|
||||||
|
|
||||||
|
y has shape (E, T, 2*H). The first half of the last dimension is
|
||||||
|
silu-activated, multiplied by the second half, then quantized into FP8.
|
||||||
|
|
||||||
|
Returns `(y_q, y_s)` where
|
||||||
|
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
|
||||||
|
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
|
||||||
|
"""
|
||||||
|
assert y.ndim == 3, "y must be (E, T, 2*H)"
|
||||||
|
E, T, H2 = y.shape
|
||||||
|
assert H2 % 2 == 0, "last dim of y must be even (2*H)"
|
||||||
|
H = H2 // 2
|
||||||
|
G = (H + group_size - 1) // group_size
|
||||||
|
assert H % group_size == 0, "H must be divisible by group_size"
|
||||||
|
assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, (
|
||||||
|
"tokens_per_expert must be shape (E,)"
|
||||||
|
)
|
||||||
|
tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)
|
||||||
|
|
||||||
|
# allocate outputs
|
||||||
|
fp8_dtype = torch.float8_e4m3fn
|
||||||
|
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
|
||||||
|
|
||||||
|
# strides (elements)
|
||||||
|
stride_i_e, stride_i_t, stride_i_h = y.stride()
|
||||||
|
stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
|
||||||
|
|
||||||
|
# desired scale strides (elements): (T*G, 1, T)
|
||||||
|
stride_ys_e = T * G
|
||||||
|
stride_ys_t = 1
|
||||||
|
stride_ys_g = T
|
||||||
|
y_s = torch.empty_strided(
|
||||||
|
(E, T, G),
|
||||||
|
(stride_ys_e, stride_ys_t, stride_ys_g),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=y.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
stride_cnt_e = tokens_per_expert.stride()[0]
|
||||||
|
|
||||||
|
# Static grid over experts and H-groups.
|
||||||
|
# A loop inside the kernel handles the token dim
|
||||||
|
grid = (E * G,)
|
||||||
|
|
||||||
|
f_info = torch.finfo(fp8_dtype)
|
||||||
|
fp8_max = f_info.max
|
||||||
|
fp8_min = f_info.min
|
||||||
|
|
||||||
|
_silu_mul_fp8_quant_deep_gemm[grid](
|
||||||
|
y,
|
||||||
|
y_q,
|
||||||
|
y_s,
|
||||||
|
tokens_per_expert,
|
||||||
|
H,
|
||||||
|
group_size,
|
||||||
|
stride_i_e,
|
||||||
|
stride_i_t,
|
||||||
|
stride_i_h,
|
||||||
|
stride_yq_e,
|
||||||
|
stride_yq_t,
|
||||||
|
stride_yq_h,
|
||||||
|
stride_ys_e,
|
||||||
|
stride_ys_t,
|
||||||
|
stride_ys_g,
|
||||||
|
stride_cnt_e,
|
||||||
|
eps,
|
||||||
|
fp8_min,
|
||||||
|
fp8_max,
|
||||||
|
is_deep_gemm_e8m0_used(),
|
||||||
|
BLOCK=group_size,
|
||||||
|
NUM_STAGES=4,
|
||||||
|
num_warps=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
return y_q, y_s
|
||||||
|
|
||||||
|
|
||||||
|
# Parse generation strategies
|
||||||
|
strategies = ["uniform", "max_t", "first_t"]
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark(
|
||||||
|
kernel: Callable,
|
||||||
|
E: int,
|
||||||
|
T: int,
|
||||||
|
H: int,
|
||||||
|
total_tokens: int,
|
||||||
|
num_parallel_tokens: int = 64,
|
||||||
|
G: int = 128,
|
||||||
|
runs: int = 200,
|
||||||
|
num_warmups: int = 20,
|
||||||
|
gen_strategy: str = "default",
|
||||||
|
iterations_per_run: int = 20,
|
||||||
|
):
|
||||||
|
def generate_data(seed_offset=0):
|
||||||
|
"""Generate input data with given seed offset"""
|
||||||
|
current_platform.seed_everything(42 + seed_offset)
|
||||||
|
y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()
|
||||||
|
|
||||||
|
if gen_strategy == "uniform":
|
||||||
|
r = torch.rand(size=(E,), device="cuda")
|
||||||
|
r /= r.sum()
|
||||||
|
r *= total_tokens
|
||||||
|
tokens_per_expert = r.int()
|
||||||
|
tokens_per_expert = torch.minimum(
|
||||||
|
tokens_per_expert,
|
||||||
|
torch.ones((E,), device=r.device, dtype=torch.int) * T,
|
||||||
|
)
|
||||||
|
elif gen_strategy == "max_t":
|
||||||
|
tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda")
|
||||||
|
tokens_per_expert.fill_(total_tokens / E)
|
||||||
|
elif gen_strategy == "first_t":
|
||||||
|
tokens_per_expert = torch.zeros(size=(E,), dtype=torch.int32, device="cuda")
|
||||||
|
tokens_per_expert[0] = min(T, total_tokens)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown generation strategy: {gen_strategy}")
|
||||||
|
return y, tokens_per_expert
|
||||||
|
|
||||||
|
dataset_count = 4
|
||||||
|
# Pre-generate different input matrices for each iteration to avoid cache effects
|
||||||
|
data_sets = [generate_data(i) for i in range(dataset_count)]
|
||||||
|
|
||||||
# Warmup
|
# Warmup
|
||||||
for _ in range(10):
|
y, tokens_per_expert = data_sets[0]
|
||||||
silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
|
for _ in range(num_warmups):
|
||||||
torch.cuda.synchronize()
|
kernel(
|
||||||
|
y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
|
||||||
# Benchmark
|
# Benchmark
|
||||||
torch.cuda.synchronize()
|
latencies: list[float] = []
|
||||||
start = time.perf_counter()
|
|
||||||
for _ in range(runs):
|
for _ in range(runs):
|
||||||
silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
|
torch.cuda.synchronize()
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
avg_time = (time.perf_counter() - start) / runs * 1000
|
start_event.record()
|
||||||
|
for i in range(iterations_per_run):
|
||||||
|
y, tokens_per_expert = data_sets[i % dataset_count]
|
||||||
|
kernel(
|
||||||
|
y,
|
||||||
|
tokens_per_expert,
|
||||||
|
num_parallel_tokens=num_parallel_tokens,
|
||||||
|
group_size=G,
|
||||||
|
)
|
||||||
|
end_event.record()
|
||||||
|
end_event.synchronize()
|
||||||
|
|
||||||
# Calculate actual work done (only count valid tokens)
|
total_time_ms = start_event.elapsed_time(end_event)
|
||||||
|
per_iter_time_ms = total_time_ms / iterations_per_run
|
||||||
|
latencies.append(per_iter_time_ms)
|
||||||
|
|
||||||
|
# Use median instead of average for better outlier handling
|
||||||
|
median_time_ms = np.median(latencies)
|
||||||
|
median_time_s = median_time_ms / 1000
|
||||||
|
|
||||||
|
# Calculate actual work done (using first dataset for consistency)
|
||||||
|
_, tokens_per_expert = data_sets[0]
|
||||||
actual_tokens = tokens_per_expert.sum().item()
|
actual_tokens = tokens_per_expert.sum().item()
|
||||||
actual_elements = actual_tokens * H
|
actual_elements = actual_tokens * H
|
||||||
|
|
||||||
# GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops
|
# GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops
|
||||||
ops_per_element = 8
|
ops_per_element = 8
|
||||||
total_ops = actual_elements * ops_per_element
|
total_ops = actual_elements * ops_per_element
|
||||||
gflops = total_ops / (avg_time / 1000) / 1e9
|
gflops = total_ops / median_time_s / 1e9
|
||||||
|
|
||||||
# Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes)
|
# Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes)
|
||||||
input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs
|
input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs
|
||||||
output_bytes = actual_tokens * H * 1 # H fp8 outputs
|
output_bytes = actual_tokens * H * 1 # H fp8 outputs
|
||||||
scale_bytes = actual_tokens * (H // G) * 4 # scales in float32
|
scale_bytes = actual_tokens * (H // G) * 4 # scales in float32
|
||||||
total_bytes = input_bytes + output_bytes + scale_bytes
|
total_bytes = input_bytes + output_bytes + scale_bytes
|
||||||
memory_bw = total_bytes / (avg_time / 1000) / 1e9
|
memory_bw = total_bytes / median_time_s / 1e9
|
||||||
|
|
||||||
return avg_time, gflops, memory_bw
|
HOPPER_BANDWIDTH_TBPS = 3.35
|
||||||
|
return (
|
||||||
|
median_time_ms,
|
||||||
|
gflops,
|
||||||
|
memory_bw,
|
||||||
|
(memory_bw / (HOPPER_BANDWIDTH_TBPS * 1024)) * 100,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_comparison_plot(
|
||||||
|
ratio, cuda_times, baseline_times, config_labels, strategy_name, id
|
||||||
|
):
|
||||||
|
"""Create a comparison plot for a specific generation strategy"""
|
||||||
|
fig, ax = plt.subplots(1, 1, figsize=(16, 6))
|
||||||
|
|
||||||
|
# Configure x-axis positions
|
||||||
|
x = np.arange(len(config_labels))
|
||||||
|
width = 0.35
|
||||||
|
|
||||||
|
# Execution Time plot (lower is better)
|
||||||
|
ax.bar(
|
||||||
|
x - width / 2, cuda_times, width, label="CUDA Kernel", alpha=0.8, color="blue"
|
||||||
|
)
|
||||||
|
ax.bar(
|
||||||
|
x + width / 2,
|
||||||
|
baseline_times,
|
||||||
|
width,
|
||||||
|
label="Baseline",
|
||||||
|
alpha=0.8,
|
||||||
|
color="orange",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add speedup labels over each bar pair
|
||||||
|
for i in range(len(x)):
|
||||||
|
speedup = ratio[i]
|
||||||
|
max_height = max(cuda_times[i], baseline_times[i])
|
||||||
|
ax.text(
|
||||||
|
x[i],
|
||||||
|
max_height + max_height * 0.02,
|
||||||
|
f"{speedup:.2f}x",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontweight="bold",
|
||||||
|
fontsize=9,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_xlabel("Configuration")
|
||||||
|
ax.set_ylabel("% Utilization")
|
||||||
|
ax.set_title(
|
||||||
|
f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)"
|
||||||
|
)
|
||||||
|
ax.set_xticks(x)
|
||||||
|
ax.set_xticklabels(config_labels, rotation=45, ha="right")
|
||||||
|
ax.legend()
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
return fig, ax
|
||||||
|
|
||||||
|
|
||||||
|
def create_combined_plot(all_results):
|
||||||
|
"""Create a combined plot with all strategies in one PNG"""
|
||||||
|
num_strategies = len(all_results)
|
||||||
|
fig, axes = plt.subplots(num_strategies, 1, figsize=(20, 6 * num_strategies))
|
||||||
|
|
||||||
|
if num_strategies == 1:
|
||||||
|
axes = [axes]
|
||||||
|
|
||||||
|
for idx, (
|
||||||
|
strategy_name,
|
||||||
|
ratio,
|
||||||
|
cuda_times,
|
||||||
|
baseline_times,
|
||||||
|
config_labels,
|
||||||
|
) in enumerate(all_results):
|
||||||
|
ax = axes[idx]
|
||||||
|
|
||||||
|
# Configure x-axis positions
|
||||||
|
x = np.arange(len(config_labels))
|
||||||
|
width = 0.35
|
||||||
|
|
||||||
|
# Execution Time plot (lower is better)
|
||||||
|
ax.bar(
|
||||||
|
x - width / 2,
|
||||||
|
cuda_times,
|
||||||
|
width,
|
||||||
|
label="CUDA Kernel",
|
||||||
|
alpha=0.8,
|
||||||
|
color="blue",
|
||||||
|
)
|
||||||
|
ax.bar(
|
||||||
|
x + width / 2,
|
||||||
|
baseline_times,
|
||||||
|
width,
|
||||||
|
label="Baseline",
|
||||||
|
alpha=0.8,
|
||||||
|
color="orange",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add speedup labels over each bar pair
|
||||||
|
for i in range(len(x)):
|
||||||
|
speedup = ratio[i]
|
||||||
|
max_height = max(cuda_times[i], baseline_times[i])
|
||||||
|
ax.text(
|
||||||
|
x[i],
|
||||||
|
max_height + max_height * 0.02,
|
||||||
|
f"{speedup:.2f}x",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontweight="bold",
|
||||||
|
fontsize=9,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_xlabel("Configuration")
|
||||||
|
ax.set_ylabel("% Utilization")
|
||||||
|
ax.set_title(
|
||||||
|
f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)"
|
||||||
|
)
|
||||||
|
ax.set_xticks(x)
|
||||||
|
ax.set_xticklabels(config_labels, rotation=45, ha="right")
|
||||||
|
ax.legend()
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
filename = "../../silu_bench/silu_benchmark_combined.png"
|
||||||
|
plt.savefig(filename, dpi=300, bbox_inches="tight")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
return filename
|
||||||
|
|
||||||
|
|
||||||
|
outer_dim = 7168
|
||||||
configs = [
|
configs = [
|
||||||
(8, 32, 1024),
|
|
||||||
(16, 64, 2048),
|
|
||||||
(32, 128, 4096),
|
|
||||||
# DeepSeekV3 Configs
|
# DeepSeekV3 Configs
|
||||||
(256, 16, 7168),
|
(8, 1024, 7168),
|
||||||
(256, 32, 7168),
|
# DeepSeekV3 Configs
|
||||||
(256, 64, 7168),
|
(32, 1024, 7168),
|
||||||
(256, 128, 7168),
|
# DeepSeekV3 Configs
|
||||||
(256, 256, 7168),
|
|
||||||
(256, 512, 7168),
|
|
||||||
(256, 1024, 7168),
|
(256, 1024, 7168),
|
||||||
]
|
]
|
||||||
|
|
||||||
print(f"GPU: {torch.cuda.get_device_name()}")
|
runs = 100
|
||||||
print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
|
num_warmups = 20
|
||||||
print("-" * 50)
|
|
||||||
|
|
||||||
for E, T, H in configs:
|
strategy_descriptions = {
|
||||||
try:
|
"uniform": "Uniform Random",
|
||||||
time_ms, gflops, gbps = benchmark(E, T, H)
|
"max_t": "Even Assignment",
|
||||||
print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
|
"first_t": "experts[0] = T, experts[1:] = 0",
|
||||||
except Exception:
|
}
|
||||||
print(f"E={E:3d},T={T:4d},H={H:4d} FAILED")
|
|
||||||
|
print(f"GPU: {torch.cuda.get_device_name()}")
|
||||||
|
print(f"Testing strategies: {', '.join(strategies)}")
|
||||||
|
print(f"Configurations: {len(configs)} configs")
|
||||||
|
|
||||||
|
all_results = []
|
||||||
|
|
||||||
|
# Run benchmarks for each strategy
|
||||||
|
for id, strategy in enumerate(strategies):
|
||||||
|
print(f"\n{'=' * 60}")
|
||||||
|
print(f"Testing strategy: {strategy_descriptions[strategy]}")
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
|
||||||
|
# Collect benchmark data for both algorithms
|
||||||
|
config_labels = []
|
||||||
|
config_x_axis = []
|
||||||
|
all_cuda_results = []
|
||||||
|
all_baseline_results = []
|
||||||
|
all_ratios = []
|
||||||
|
|
||||||
|
for E, T, H in configs:
|
||||||
|
total_tokens_config = [8 * E, 16 * E, 32 * E, 64 * E, 128 * E, 256 * E]
|
||||||
|
config_x_axis.append(total_tokens_config)
|
||||||
|
|
||||||
|
cuda_results = []
|
||||||
|
baseline_results = []
|
||||||
|
ratios = []
|
||||||
|
|
||||||
|
for total_tokens in total_tokens_config:
|
||||||
|
config_label = f"E={E},T={T},H={H},TT={total_tokens}"
|
||||||
|
config_labels.append(config_label)
|
||||||
|
|
||||||
|
# CUDA kernel results
|
||||||
|
time_ms_cuda, gflops, gbps, perc = benchmark(
|
||||||
|
silu_mul_fp8_quant_deep_gemm_cuda,
|
||||||
|
E,
|
||||||
|
T,
|
||||||
|
H,
|
||||||
|
total_tokens,
|
||||||
|
runs=runs,
|
||||||
|
num_warmups=num_warmups,
|
||||||
|
gen_strategy=strategy,
|
||||||
|
)
|
||||||
|
cuda_results.append((time_ms_cuda, gflops, gbps, perc))
|
||||||
|
|
||||||
|
# Baseline results
|
||||||
|
time_ms_triton, gflops, gbps, perc = benchmark(
|
||||||
|
silu_mul_fp8_quant_deep_gemm_triton,
|
||||||
|
E,
|
||||||
|
T,
|
||||||
|
H,
|
||||||
|
total_tokens,
|
||||||
|
runs=runs,
|
||||||
|
num_warmups=num_warmups,
|
||||||
|
gen_strategy=strategy,
|
||||||
|
)
|
||||||
|
baseline_results.append((time_ms_triton, gflops, gbps, perc))
|
||||||
|
ratios.append(time_ms_triton / time_ms_cuda)
|
||||||
|
|
||||||
|
print(f"Completed: {config_label}")
|
||||||
|
all_cuda_results.append(cuda_results)
|
||||||
|
all_baseline_results.append(baseline_results)
|
||||||
|
all_ratios.append(ratios)
|
||||||
|
|
||||||
|
# Store results for combined plotting
|
||||||
|
all_results.append(
|
||||||
|
(
|
||||||
|
strategy_descriptions[strategy],
|
||||||
|
all_ratios,
|
||||||
|
all_cuda_results,
|
||||||
|
all_baseline_results,
|
||||||
|
config_labels,
|
||||||
|
config_x_axis,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print summary table for this strategy
|
||||||
|
print(f"\nSummary Table - {strategy_descriptions[strategy]}:")
|
||||||
|
print(f"{'Config':<20} {'CUDA Time(ms)':<12} {'Base Time(ms)':<12} {'Speedup':<8}")
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
|
for i, (E, T, H) in enumerate(configs):
|
||||||
|
speedup = baseline_results[i][0] / cuda_results[i][0]
|
||||||
|
config_label = f"E={E:3d},T={T:4d},H={H:4d}"
|
||||||
|
print(
|
||||||
|
f"{config_label:<20} {cuda_results[i][0]:8.5f} "
|
||||||
|
f"{baseline_results[i][0]:8.5f} {speedup:6.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_total_tokens_plot(all_results):
|
||||||
|
num_strategies = len(all_results)
|
||||||
|
num_configs = len(configs)
|
||||||
|
|
||||||
|
# Create side-by-side subplots: 2 columns for speedup and bandwidth percentage
|
||||||
|
fig, axs = plt.subplots(
|
||||||
|
num_strategies, num_configs * 2, figsize=(28, 6 * num_strategies)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add main title to the entire figure
|
||||||
|
fig.suptitle(
|
||||||
|
"Performance Analysis: Speedup vs Bandwidth Utilization (Triton & CUDA)",
|
||||||
|
fontsize=16,
|
||||||
|
fontweight="bold",
|
||||||
|
y=0.98,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle single strategy case
|
||||||
|
if num_strategies == 1:
|
||||||
|
axs = axs.reshape(1, -1)
|
||||||
|
|
||||||
|
# Handle single config case
|
||||||
|
if num_configs == 1:
|
||||||
|
axs = axs.reshape(-1, 2)
|
||||||
|
|
||||||
|
for strategy_idx, result in enumerate(all_results):
|
||||||
|
(
|
||||||
|
strategy_name,
|
||||||
|
all_ratios,
|
||||||
|
all_cuda_results,
|
||||||
|
all_baseline_results,
|
||||||
|
config_labels,
|
||||||
|
config_x_axis,
|
||||||
|
) = result
|
||||||
|
|
||||||
|
for config_idx in range(num_configs):
|
||||||
|
# Speedup plot (left column)
|
||||||
|
ax_speedup = axs[strategy_idx, config_idx * 2]
|
||||||
|
# Bandwidth plot (right column)
|
||||||
|
ax_bandwidth = axs[strategy_idx, config_idx * 2 + 1]
|
||||||
|
|
||||||
|
E, T, H = configs[config_idx]
|
||||||
|
ratios = all_ratios[config_idx]
|
||||||
|
total_tokens_values = config_x_axis[config_idx]
|
||||||
|
|
||||||
|
# Extract CUDA and Triton bandwidth percentages
|
||||||
|
cuda_bandwidth_percentages = [
|
||||||
|
result[3] for result in all_cuda_results[config_idx]
|
||||||
|
]
|
||||||
|
triton_bandwidth_percentages = [
|
||||||
|
result[3] for result in all_baseline_results[config_idx]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Plot speedup ratios vs total tokens (left plot)
|
||||||
|
ax_speedup.plot(
|
||||||
|
total_tokens_values, ratios, "bo-", linewidth=3, markersize=8
|
||||||
|
)
|
||||||
|
ax_speedup.set_title(
|
||||||
|
f"{strategy_name}\nSpeedup (CUDA/Triton)\nE={E}, T={T}, H={H}",
|
||||||
|
fontsize=12,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
|
||||||
|
ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11)
|
||||||
|
ax_speedup.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
ax_bandwidth.plot(
|
||||||
|
total_tokens_values,
|
||||||
|
cuda_bandwidth_percentages,
|
||||||
|
"ro-",
|
||||||
|
linewidth=3,
|
||||||
|
markersize=8,
|
||||||
|
label="CUDA",
|
||||||
|
)
|
||||||
|
ax_bandwidth.plot(
|
||||||
|
total_tokens_values,
|
||||||
|
triton_bandwidth_percentages,
|
||||||
|
"go-",
|
||||||
|
linewidth=3,
|
||||||
|
markersize=8,
|
||||||
|
label="Triton",
|
||||||
|
)
|
||||||
|
ax_bandwidth.set_title(
|
||||||
|
f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}",
|
||||||
|
fontsize=12,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
ax_bandwidth.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
|
||||||
|
ax_bandwidth.set_ylabel(
|
||||||
|
"% of Peak Bandwidth", fontweight="bold", fontsize=11
|
||||||
|
)
|
||||||
|
ax_bandwidth.legend(prop={"weight": "bold"})
|
||||||
|
ax_bandwidth.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
# Format x-axis labels for both plots
|
||||||
|
for ax in [ax_speedup, ax_bandwidth]:
|
||||||
|
ax.set_xticks(total_tokens_values)
|
||||||
|
ax.set_xticklabels(
|
||||||
|
[
|
||||||
|
f"{tt // 1000}K" if tt >= 1000 else str(tt)
|
||||||
|
for tt in total_tokens_values
|
||||||
|
],
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
# Make tick labels bold
|
||||||
|
for label in ax.get_xticklabels() + ax.get_yticklabels():
|
||||||
|
label.set_fontweight("bold")
|
||||||
|
|
||||||
|
# Add value labels on speedup points
|
||||||
|
for x, y in zip(total_tokens_values, ratios):
|
||||||
|
ax_speedup.annotate(
|
||||||
|
f"{y:.2f}x",
|
||||||
|
(x, y),
|
||||||
|
textcoords="offset points",
|
||||||
|
xytext=(0, 12),
|
||||||
|
ha="center",
|
||||||
|
fontsize=10,
|
||||||
|
fontweight="bold",
|
||||||
|
bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add value labels on CUDA bandwidth points
|
||||||
|
for x, y in zip(total_tokens_values, cuda_bandwidth_percentages):
|
||||||
|
ax_bandwidth.annotate(
|
||||||
|
f"{y:.1f}%",
|
||||||
|
(x, y),
|
||||||
|
textcoords="offset points",
|
||||||
|
xytext=(0, 12),
|
||||||
|
ha="center",
|
||||||
|
fontsize=9,
|
||||||
|
fontweight="bold",
|
||||||
|
bbox=dict(boxstyle="round,pad=0.2", facecolor="red", alpha=0.3),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add value labels on Triton bandwidth points
|
||||||
|
for x, y in zip(total_tokens_values, triton_bandwidth_percentages):
|
||||||
|
ax_bandwidth.annotate(
|
||||||
|
f"{y:.1f}%",
|
||||||
|
(x, y),
|
||||||
|
textcoords="offset points",
|
||||||
|
xytext=(0, -15),
|
||||||
|
ha="center",
|
||||||
|
fontsize=9,
|
||||||
|
fontweight="bold",
|
||||||
|
bbox=dict(boxstyle="round,pad=0.2", facecolor="green", alpha=0.3),
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.subplots_adjust(top=0.93) # Make room for main title
|
||||||
|
filename = "silu_benchmark_total_tokens.png"
|
||||||
|
plt.savefig(filename, dpi=300, bbox_inches="tight")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
return filename
|
||||||
|
|
||||||
|
|
||||||
|
# Create combined plot with all strategies
|
||||||
|
combined_plot_filename = create_total_tokens_plot(all_results)
|
||||||
|
|
||||||
|
print(f"\n{'=' * 60}")
|
||||||
|
print("Benchmark Complete!")
|
||||||
|
print(f"Generated combined plot: {combined_plot_filename}")
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
|||||||
@ -43,6 +43,7 @@ void sm100_cutlass_mla_decode(
|
|||||||
torch::Tensor const& seq_lens,
|
torch::Tensor const& seq_lens,
|
||||||
torch::Tensor const& page_table,
|
torch::Tensor const& page_table,
|
||||||
torch::Tensor const& workspace,
|
torch::Tensor const& workspace,
|
||||||
|
double sm_scale,
|
||||||
int64_t num_kv_splits) {
|
int64_t num_kv_splits) {
|
||||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
|
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
|
||||||
}
|
}
|
||||||
|
|||||||
12
csrc/ops.h
12
csrc/ops.h
@ -122,12 +122,6 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
|||||||
std::optional<torch::Tensor> key, int64_t head_size,
|
std::optional<torch::Tensor> key, int64_t head_size,
|
||||||
torch::Tensor& cos_sin_cache, bool is_neox);
|
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||||
|
|
||||||
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
|
||||||
std::optional<torch::Tensor> key,
|
|
||||||
int64_t head_size, torch::Tensor& cos_sin_cache,
|
|
||||||
bool is_neox, int64_t rot_dim,
|
|
||||||
torch::Tensor& cos_sin_cache_offsets);
|
|
||||||
|
|
||||||
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||||
|
|
||||||
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
|
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
|
||||||
@ -139,6 +133,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
|
|||||||
torch::Tensor& input,
|
torch::Tensor& input,
|
||||||
torch::Tensor& input_global_scale);
|
torch::Tensor& input_global_scale);
|
||||||
#endif
|
#endif
|
||||||
|
void silu_mul_fp8_quant_deep_gemm_cuda(
|
||||||
|
const at::Tensor& input, // (E, T, 2*H)
|
||||||
|
const at::Tensor& counts, // (E)
|
||||||
|
at::Tensor& y_q, // (E, T, H) [OUT]
|
||||||
|
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
|
||||||
|
int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens);
|
||||||
|
|
||||||
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
|
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
|
||||||
|
|
||||||
|
|||||||
@ -99,35 +99,6 @@ __global__ void rotary_embedding_kernel(
|
|||||||
token_idx, query_stride, key_stride, head_stride);
|
token_idx, query_stride, key_stride, head_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, bool IS_NEOX>
|
|
||||||
__global__ void batched_rotary_embedding_kernel(
|
|
||||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
|
|
||||||
// [num_tokens]
|
|
||||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
|
|
||||||
// head_size] or [num_tokens, num_heads,
|
|
||||||
// head_size]
|
|
||||||
scalar_t* __restrict__ key, // nullptr or
|
|
||||||
// [batch_size, seq_len, num_kv_heads,
|
|
||||||
// head_size] or [num_tokens, num_kv_heads,
|
|
||||||
// head_size]
|
|
||||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
|
||||||
// 2]
|
|
||||||
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
|
|
||||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
|
||||||
const int64_t head_stride, const int num_heads, const int num_kv_heads,
|
|
||||||
const int head_size) {
|
|
||||||
// Each thread block is responsible for one token.
|
|
||||||
const int token_idx = blockIdx.x;
|
|
||||||
int64_t pos = positions[token_idx];
|
|
||||||
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
|
|
||||||
const scalar_t* cache_ptr =
|
|
||||||
cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
|
|
||||||
|
|
||||||
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
|
||||||
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
|
||||||
token_idx, query_stride, key_stride, head_stride);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void rotary_embedding(
|
void rotary_embedding(
|
||||||
@ -211,96 +182,3 @@ void rotary_embedding(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
Batched version of rotary embedding, pack multiple LoRAs together
|
|
||||||
and process in batched manner.
|
|
||||||
*/
|
|
||||||
void batched_rotary_embedding(
|
|
||||||
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
|
||||||
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
|
|
||||||
// [num_tokens, num_heads * head_size] or
|
|
||||||
// [batch_size, seq_len, num_heads, head_size] or
|
|
||||||
// [num_tokens, num_heads, head_size]
|
|
||||||
std::optional<torch::Tensor>
|
|
||||||
key, // null or
|
|
||||||
// [batch_size, seq_len, num_kv_heads * head_size] or
|
|
||||||
// [num_tokens, num_kv_heads * head_size] or
|
|
||||||
// [batch_size, seq_len, num_heads, head_size] or
|
|
||||||
// [num_tokens, num_heads, head_size]
|
|
||||||
int64_t head_size,
|
|
||||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
|
||||||
bool is_neox, int64_t rot_dim,
|
|
||||||
torch::Tensor& cos_sin_cache_offsets // [num_tokens] or [batch_size]
|
|
||||||
) {
|
|
||||||
// num_tokens = batch_size * seq_len
|
|
||||||
int64_t num_tokens = cos_sin_cache_offsets.size(0);
|
|
||||||
TORCH_CHECK(
|
|
||||||
positions.size(0) == num_tokens || positions.numel() == num_tokens,
|
|
||||||
"positions must have the same num_tokens or batch_size as "
|
|
||||||
"cos_sin_cache_offsets");
|
|
||||||
|
|
||||||
int positions_ndim = positions.dim();
|
|
||||||
// Make sure num_tokens dim is consistent across positions, query, and key
|
|
||||||
TORCH_CHECK(
|
|
||||||
positions_ndim == 1 || positions_ndim == 2,
|
|
||||||
"positions must have shape [num_tokens] or [batch_size, seq_len]");
|
|
||||||
if (positions_ndim == 1) {
|
|
||||||
TORCH_CHECK(query.size(0) == positions.size(0) &&
|
|
||||||
(!key.has_value() || key->size(0) == positions.size(0)),
|
|
||||||
"query, key and positions must have the same number of tokens");
|
|
||||||
}
|
|
||||||
if (positions_ndim == 2) {
|
|
||||||
TORCH_CHECK(
|
|
||||||
query.size(0) == positions.size(0) &&
|
|
||||||
(!key.has_value() || key->size(0) == positions.size(0)) &&
|
|
||||||
query.size(1) == positions.size(1) &&
|
|
||||||
(!key.has_value() || key->size(1) == positions.size(1)),
|
|
||||||
"query, key and positions must have the same batch_size and seq_len");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure head_size is valid for query and key
|
|
||||||
int query_hidden_size = query.numel() / num_tokens;
|
|
||||||
int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
|
|
||||||
TORCH_CHECK(query_hidden_size % head_size == 0);
|
|
||||||
TORCH_CHECK(key_hidden_size % head_size == 0);
|
|
||||||
|
|
||||||
// Make sure query and key have concistent number of heads
|
|
||||||
int num_heads = query_hidden_size / head_size;
|
|
||||||
int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
|
|
||||||
TORCH_CHECK(num_heads % num_kv_heads == 0);
|
|
||||||
|
|
||||||
int seq_dim_idx = positions_ndim - 1;
|
|
||||||
int64_t query_stride = query.stride(seq_dim_idx);
|
|
||||||
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
|
||||||
// Determine head stride: for [*, heads, head_size] use stride of last dim;
|
|
||||||
// for flat [*, heads*head_size], heads blocks are contiguous of size
|
|
||||||
// head_size
|
|
||||||
int query_ndim = query.dim();
|
|
||||||
int64_t head_stride =
|
|
||||||
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
|
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
|
||||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
|
|
||||||
if (is_neox) {
|
|
||||||
vllm::batched_rotary_embedding_kernel<scalar_t, true>
|
|
||||||
<<<grid, block, 0, stream>>>(
|
|
||||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
|
||||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
|
||||||
cos_sin_cache.data_ptr<scalar_t>(),
|
|
||||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
|
||||||
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
|
||||||
} else {
|
|
||||||
vllm::batched_rotary_embedding_kernel<scalar_t, false>
|
|
||||||
<<<grid, block, 0, stream>>>(
|
|
||||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
|
||||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
|
||||||
cos_sin_cache.data_ptr<scalar_t>(),
|
|
||||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
|
||||||
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|||||||
@ -9,6 +9,26 @@
|
|||||||
|
|
||||||
#include "quantization/w8a8/fp8/common.cuh"
|
#include "quantization/w8a8/fp8/common.cuh"
|
||||||
|
|
||||||
|
#include <c10/util/Float8_e4m3fn.h>
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cuda_fp8.h>
|
||||||
|
#else
|
||||||
|
#include <hip/hip_bf16.h>
|
||||||
|
#include <hip/hip_fp16.h>
|
||||||
|
#include <hip/hip_fp8.h>
|
||||||
|
|
||||||
|
typedef __hip_bfloat162 __nv_bfloat162;
|
||||||
|
typedef __hip_bfloat16 __nv_bfloat16;
|
||||||
|
typedef __hip_bfloat16_raw __nv_bfloat16_raw;
|
||||||
|
|
||||||
|
typedef __hip_fp8_e4m3 __nv_fp8_e4m3;
|
||||||
|
typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "core/registration.h"
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -87,6 +107,337 @@ __global__ void act_and_mul_quant_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ float silu(float x) {
|
||||||
|
return (__fdividef(x, (1.f + expf(-x))));
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ float2 silu2(float2 x) {
|
||||||
|
return make_float2(silu(x.x), silu(x.y));
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
__device__ __forceinline__ float warp_max(float v) {
|
||||||
|
static constexpr unsigned FULL_MASK = 0xffffffffu;
|
||||||
|
for (int offset = 1; offset < WARP_SIZE; offset *= 2) {
|
||||||
|
v = fmaxf(v, __shfl_xor_sync(FULL_MASK, v, offset));
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ __nv_bfloat16 warp_max(__nv_bfloat16 v) {
|
||||||
|
static constexpr unsigned FULL_MASK = 0xffffffffu;
|
||||||
|
for (int offset = 1; offset < WARP_SIZE; offset *= 2) {
|
||||||
|
v = __hmax(v, __shfl_xor_sync(FULL_MASK, v, offset));
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
__device__ __forceinline__ void cp_async4(T* _smem_ptr, const U* _glob_ptr) {
|
||||||
|
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||||
|
auto smem_ptr = reinterpret_cast<void*>(_smem_ptr);
|
||||||
|
auto glob_ptr = reinterpret_cast<const void*>(_glob_ptr);
|
||||||
|
const int BYTES = 16;
|
||||||
|
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||||
|
asm volatile(
|
||||||
|
"{\n"
|
||||||
|
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||||
|
"}\n" ::"r"(smem),
|
||||||
|
"l"(glob_ptr), "n"(BYTES));
|
||||||
|
#else
|
||||||
|
_smem_ptr[0] = _glob_ptr[0];
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void cp_async_fence() {
|
||||||
|
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||||
|
asm volatile("cp.async.commit_group;\n" ::);
|
||||||
|
#else
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int N>
|
||||||
|
__device__ __forceinline__ void cp_async_wait() {
|
||||||
|
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||||
|
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
|
||||||
|
#else
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __forceinline__ void cp_async_wait<0>() {
|
||||||
|
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||||
|
asm volatile("cp.async.wait_all;\n" ::);
|
||||||
|
#else
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ float clip(float v, float mmin, float mmax) {
|
||||||
|
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||||
|
return fminf(mmax, fmaxf(v, mmin));
|
||||||
|
#else
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v,
|
||||||
|
__nv_bfloat16 mmin,
|
||||||
|
__nv_bfloat16 mmax) {
|
||||||
|
return __hmin(mmax, __hmax(v, mmin));
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ __nv_bfloat162 clip(__nv_bfloat162 v,
|
||||||
|
__nv_bfloat162 mmin,
|
||||||
|
__nv_bfloat162 mmax) {
|
||||||
|
return __hmin2(mmax, __hmax2(v, mmin));
|
||||||
|
}
|
||||||
|
|
||||||
|
// We use the following values for fp8 min/max:
|
||||||
|
// __nv_fp8_e4m3 = (-448, +448)
|
||||||
|
// __nv_fp8_e4m3uz = (-240.0, +240.0)
|
||||||
|
// It is currently assumed that only
|
||||||
|
template <class T>
|
||||||
|
constexpr __nv_bfloat16 get_fp8_max() {
|
||||||
|
static_assert(std::is_same_v<T, c10::Float8_e4m3fn> ||
|
||||||
|
std::is_same_v<T, c10::Float8_e4m3fnuz>);
|
||||||
|
if constexpr (std::is_same_v<T, c10::Float8_e4m3fn>) {
|
||||||
|
return __nv_bfloat16(__nv_bfloat16_raw{.x = 17376});
|
||||||
|
} else {
|
||||||
|
return __nv_bfloat16(__nv_bfloat16_raw{.x = 17264});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
constexpr __nv_bfloat16 get_fp8_min() {
|
||||||
|
static_assert(std::is_same_v<T, c10::Float8_e4m3fn> ||
|
||||||
|
std::is_same_v<T, c10::Float8_e4m3fnuz>);
|
||||||
|
if constexpr (std::is_same_v<T, c10::Float8_e4m3fn>) {
|
||||||
|
return __nv_bfloat16(__nv_bfloat16_raw{.x = 50144});
|
||||||
|
} else {
|
||||||
|
return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
template <typename fp8_type, int32_t NUM_WARPS, typename Idx_t,
|
||||||
|
int NUM_PARALLEL_TOKENS, bool USE_UE8M0, int GROUP_SIZE = 128,
|
||||||
|
int NUM_STAGES = 3>
|
||||||
|
__global__ void silu_mul_fp8_quant_deep_gemm_kernel(
|
||||||
|
const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
|
||||||
|
float* __restrict__ _y_s, const int32_t* __restrict__ counts,
|
||||||
|
|
||||||
|
// sizes
|
||||||
|
int H, int G,
|
||||||
|
|
||||||
|
// strides (in elements)
|
||||||
|
Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
|
||||||
|
Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
|
||||||
|
Idx_t stride_ys_g, Idx_t stride_counts_e) {
|
||||||
|
static constexpr __nv_bfloat16 fp8_min = get_fp8_min<fp8_type>();
|
||||||
|
static constexpr __nv_bfloat16 fp8_max = get_fp8_max<fp8_type>();
|
||||||
|
// We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
|
||||||
|
static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996});
|
||||||
|
|
||||||
|
// We pack 8 16-bit bfloat16 values into a 128-bit __int128_t.
|
||||||
|
static constexpr int32_t BFLOAT16_PER_GROUP = 8;
|
||||||
|
|
||||||
|
// We split the shared memory in half, corresponding to gate and up matrices:
|
||||||
|
// [...gate_i, ...up_i] where 0 <= i < stages.
|
||||||
|
static constexpr int32_t S_NUM_128 =
|
||||||
|
2u * (GROUP_SIZE / BFLOAT16_PER_GROUP) * NUM_WARPS * NUM_STAGES;
|
||||||
|
static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE;
|
||||||
|
static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2;
|
||||||
|
static constexpr int32_t S_NUM_64 = S_NUM_128 * 2;
|
||||||
|
__shared__ __int128_t __align__(16) s_buff_128[S_NUM_128];
|
||||||
|
|
||||||
|
const int32_t tid = threadIdx.x;
|
||||||
|
const int32_t warp_id = tid / WARP_SIZE;
|
||||||
|
const int32_t lane_id = tid % WARP_SIZE;
|
||||||
|
|
||||||
|
auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128);
|
||||||
|
|
||||||
|
// block handles one (expert e, group g)
|
||||||
|
int32_t pid = blockIdx.x;
|
||||||
|
int32_t e = pid / G;
|
||||||
|
int32_t g = pid % G;
|
||||||
|
|
||||||
|
const int32_t n_tokens = counts[e * stride_counts_e];
|
||||||
|
|
||||||
|
if (!n_tokens) {
|
||||||
|
return; // Exit ASAP.
|
||||||
|
}
|
||||||
|
|
||||||
|
const Idx_t stride_i_t_128 = stride_i_t / 8u;
|
||||||
|
|
||||||
|
int32_t n_tokens_lower, n_tokens_upper;
|
||||||
|
|
||||||
|
// Each block i iterates over tokens of a slice of n_tokens =
|
||||||
|
// expert_counts[i], with the size of chunk being
|
||||||
|
// (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of
|
||||||
|
// updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling.
|
||||||
|
if (n_tokens < NUM_PARALLEL_TOKENS && blockIdx.y < n_tokens) {
|
||||||
|
// Specialize this, but can be likely fused.
|
||||||
|
if (blockIdx.y >= NUM_PARALLEL_TOKENS) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
n_tokens_lower = blockIdx.y;
|
||||||
|
n_tokens_upper = blockIdx.y + 1;
|
||||||
|
} else {
|
||||||
|
auto chunk_size = n_tokens / NUM_PARALLEL_TOKENS;
|
||||||
|
auto residual = n_tokens - chunk_size * NUM_PARALLEL_TOKENS;
|
||||||
|
auto calc_id = [&](int32_t id) {
|
||||||
|
if (id < residual) {
|
||||||
|
return min(n_tokens, id * (chunk_size + 1));
|
||||||
|
} else {
|
||||||
|
return min(n_tokens, id * chunk_size + residual);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
n_tokens_lower = calc_id(blockIdx.y);
|
||||||
|
n_tokens_upper = calc_id(blockIdx.y + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_tokens_lower >= n_tokens_upper) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// We do calculations here, using constexpr wherever possible.
|
||||||
|
const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h;
|
||||||
|
const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g;
|
||||||
|
const Idx_t base_yq =
|
||||||
|
e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h;
|
||||||
|
Idx_t gate_off_128 = (base_i / static_cast<Idx_t>(8u));
|
||||||
|
auto input_128_ptr = reinterpret_cast<const __int128_t*>(_input);
|
||||||
|
auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % HALF_THREAD_COUNT) +
|
||||||
|
stride_i_t_128 * n_tokens_lower;
|
||||||
|
auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u;
|
||||||
|
auto y_s_ptr =
|
||||||
|
_y_s + base_ys + warp_id * stride_ys_g + n_tokens_lower * stride_ys_t;
|
||||||
|
auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE +
|
||||||
|
stride_yq_t * n_tokens_lower + 4 * lane_id;
|
||||||
|
int32_t t_load = n_tokens_lower, load_stage_id = 0;
|
||||||
|
auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT);
|
||||||
|
auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u;
|
||||||
|
int32_t stage_offset{};
|
||||||
|
|
||||||
|
static constexpr int32_t LOAD_STAGE_SIZE = (NUM_WARPS * WARP_SIZE / 2);
|
||||||
|
static constexpr int32_t LOAD_STAGE_MOD =
|
||||||
|
NUM_STAGES * (NUM_WARPS * WARP_SIZE / 2);
|
||||||
|
|
||||||
|
// Two halves of all threads in a block conduct global loads for gate and up,
|
||||||
|
// repsectively.
|
||||||
|
auto load_and_advance_y_pred = [&] {
|
||||||
|
if (t_load < n_tokens_upper) {
|
||||||
|
auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset;
|
||||||
|
auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset;
|
||||||
|
|
||||||
|
// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
|
||||||
|
// unnecessary ALU ops.
|
||||||
|
stage_offset += LOAD_STAGE_SIZE;
|
||||||
|
stage_offset %= LOAD_STAGE_MOD;
|
||||||
|
|
||||||
|
if (tid < HALF_THREAD_COUNT) {
|
||||||
|
cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr);
|
||||||
|
gate_128_ptr += stride_i_t_128;
|
||||||
|
} else {
|
||||||
|
cp_async4(s_up_stage_128_staged_ptr, up_128_ptr);
|
||||||
|
up_128_ptr += stride_i_t_128;
|
||||||
|
}
|
||||||
|
++t_load;
|
||||||
|
++load_stage_id;
|
||||||
|
}
|
||||||
|
// We fence even if there is nothing to load to simplify pipelining.
|
||||||
|
cp_async_fence();
|
||||||
|
};
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NUM_STAGES - 1; i++) {
|
||||||
|
load_and_advance_y_pred();
|
||||||
|
}
|
||||||
|
|
||||||
|
__int64_t* s_gate_ptr = reinterpret_cast<__int64_t*>(
|
||||||
|
s_buff_compute_32 + warp_id * (GROUP_SIZE / 2)) +
|
||||||
|
lane_id;
|
||||||
|
__int64_t* s_up_ptr = s_gate_ptr + S_NUM_64 / 2;
|
||||||
|
|
||||||
|
static constexpr int32_t STAGE_SIZE = (GROUP_SIZE * NUM_WARPS) / 4u;
|
||||||
|
static constexpr int32_t STAGE_MOD = STAGE_SIZE * NUM_STAGES;
|
||||||
|
|
||||||
|
int32_t compute_pipeline_offset_64 = 0;
|
||||||
|
|
||||||
|
for (int32_t t = n_tokens_lower; t < n_tokens_upper; ++t) {
|
||||||
|
__nv_bfloat16 y_max_bf16 = EPS;
|
||||||
|
__nv_bfloat162 results_bf162[2];
|
||||||
|
|
||||||
|
cp_async_wait<NUM_STAGES - 2>();
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// We double-buffer pipelined loads so that the next load will
|
||||||
|
// concurrently run with compute without overwrites.
|
||||||
|
load_and_advance_y_pred();
|
||||||
|
|
||||||
|
auto s_gate_compute_64 = s_gate_ptr + compute_pipeline_offset_64;
|
||||||
|
auto s_up_compute_64 = s_up_ptr + compute_pipeline_offset_64;
|
||||||
|
|
||||||
|
// STAGE_SIZE must also be constexpr!
|
||||||
|
compute_pipeline_offset_64 += STAGE_SIZE;
|
||||||
|
compute_pipeline_offset_64 %= STAGE_MOD;
|
||||||
|
|
||||||
|
// Each thread loads (gate/up) 2X 4X bfloat16 values into registers.
|
||||||
|
__int64_t gate64 = *s_gate_compute_64;
|
||||||
|
__nv_bfloat162* s_gate_compute_32 =
|
||||||
|
reinterpret_cast<__nv_bfloat162*>(&gate64);
|
||||||
|
|
||||||
|
__int64_t up64 = *s_up_compute_64;
|
||||||
|
__nv_bfloat162* s_up_compute_32 = reinterpret_cast<__nv_bfloat162*>(&up64);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 2; i++) {
|
||||||
|
// For silu, we make sure that div is emitted.
|
||||||
|
float2 gate = silu2(__bfloat1622float2(s_gate_compute_32[i]));
|
||||||
|
results_bf162[i] = __float22bfloat162_rn(gate);
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 2; i++) {
|
||||||
|
results_bf162[i] = __hmul2(results_bf162[i], s_up_compute_32[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto _y_max2 =
|
||||||
|
__hmax2(__habs2(results_bf162[0]), __habs2(results_bf162[1]));
|
||||||
|
|
||||||
|
y_max_bf16 = __hmax(_y_max2.x, _y_max2.y);
|
||||||
|
|
||||||
|
// An entire group is assigned to a single warp, so a simple warp reduce
|
||||||
|
// is used.
|
||||||
|
__nv_bfloat16 y_s = warp_max(y_max_bf16) / fp8_max;
|
||||||
|
|
||||||
|
if constexpr (USE_UE8M0) {
|
||||||
|
y_s = hexp2(hceil(hlog2(y_s)));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto inv_y = __float2bfloat16_rn(1.f) / y_s;
|
||||||
|
|
||||||
|
auto y_s2 = make_bfloat162(inv_y, inv_y);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int32_t i = 0; i < 2; ++i) {
|
||||||
|
results_bf162[i] =
|
||||||
|
clip(__hmul2(results_bf162[i], y_s2), __bfloat162bfloat162(fp8_min),
|
||||||
|
__bfloat162bfloat162(fp8_max));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto fp8x4 = __nv_fp8x4_e4m3(results_bf162[0], results_bf162[1]);
|
||||||
|
*reinterpret_cast<__nv_fp8x4_e4m3*>(y_q_ptr) = fp8x4;
|
||||||
|
y_q_ptr += stride_yq_t;
|
||||||
|
|
||||||
|
if (lane_id == 0) {
|
||||||
|
*y_s_ptr = y_s;
|
||||||
|
y_s_ptr += stride_ys_t;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
// Launch activation, gating, and quantize kernel.
|
// Launch activation, gating, and quantize kernel.
|
||||||
@ -119,3 +470,117 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d]
|
|||||||
TORCH_CHECK(input.size(-1) % 2 == 0);
|
TORCH_CHECK(input.size(-1) % 2 == 0);
|
||||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
|
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void silu_mul_fp8_quant_deep_gemm_cuda(
|
||||||
|
const at::Tensor& input, // (E, T, 2*H)
|
||||||
|
const at::Tensor& counts, // (E)
|
||||||
|
at::Tensor& y_q, // (E, T, H) [OUT]
|
||||||
|
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
|
||||||
|
int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens) {
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
// This kernel relies heavily on cp.async and fp8 support.
|
||||||
|
// This kernel currently only supports H % 128 == 0 and assumes a
|
||||||
|
// fixed GROUP_SIZE of 128.
|
||||||
|
TORCH_CHECK(input.dtype() == torch::kBFloat16);
|
||||||
|
TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
|
||||||
|
y_q.dtype() == torch::kFloat8_e4m3fnuz);
|
||||||
|
TORCH_CHECK(y_s.dtype() == torch::kFloat32);
|
||||||
|
TORCH_CHECK(input.size(-1) % 256 == 0);
|
||||||
|
|
||||||
|
// Check that num_parallel_tokens is of power of 2 and between 1 and 64.
|
||||||
|
TORCH_CHECK(1 <= num_parallel_tokens && num_parallel_tokens <= 64);
|
||||||
|
TORCH_CHECK(!(num_parallel_tokens & (num_parallel_tokens - 1)));
|
||||||
|
|
||||||
|
using Idx_t = int64_t;
|
||||||
|
|
||||||
|
Idx_t E = input.size(0);
|
||||||
|
Idx_t T = input.size(1);
|
||||||
|
Idx_t H = input.size(2) / 2;
|
||||||
|
Idx_t stride_i_e = input.stride(0);
|
||||||
|
Idx_t stride_i_t = input.stride(1);
|
||||||
|
Idx_t stride_i_h = input.stride(2);
|
||||||
|
Idx_t stride_yq_e = y_q.stride(0);
|
||||||
|
Idx_t stride_yq_t = y_q.stride(1);
|
||||||
|
Idx_t stride_yq_h = y_q.stride(2);
|
||||||
|
Idx_t stride_ys_e = y_s.stride(0);
|
||||||
|
Idx_t stride_ys_t = y_s.stride(1);
|
||||||
|
Idx_t stride_ys_g = y_s.stride(2);
|
||||||
|
|
||||||
|
Idx_t stride_counts_e = counts.stride(0);
|
||||||
|
|
||||||
|
static constexpr int GROUP_SIZE = 128;
|
||||||
|
|
||||||
|
#define KERNEL_FN \
|
||||||
|
if (use_ue8m0) { \
|
||||||
|
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
|
||||||
|
NUM_PARALLEL_TOKENS, true> \
|
||||||
|
<<<grid, block, 0, stream>>>( \
|
||||||
|
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
|
||||||
|
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
|
||||||
|
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
|
||||||
|
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
|
||||||
|
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
|
||||||
|
stride_counts_e); \
|
||||||
|
} else { \
|
||||||
|
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
|
||||||
|
NUM_PARALLEL_TOKENS, false> \
|
||||||
|
<<<grid, block, 0, stream>>>( \
|
||||||
|
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
|
||||||
|
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
|
||||||
|
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
|
||||||
|
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
|
||||||
|
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
|
||||||
|
stride_counts_e); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define KERNEL_CALL_H \
|
||||||
|
if (H % (4 * GROUP_SIZE) == 0) { \
|
||||||
|
static constexpr int NUM_WARPS = 4; \
|
||||||
|
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
|
||||||
|
KERNEL_FN \
|
||||||
|
} else { \
|
||||||
|
static constexpr int NUM_WARPS = 1; \
|
||||||
|
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
|
||||||
|
KERNEL_FN \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define KERNEL_CALL_TOP_LEVEL \
|
||||||
|
if (num_parallel_tokens == 1) { \
|
||||||
|
static constexpr int NUM_PARALLEL_TOKENS = 1; \
|
||||||
|
KERNEL_CALL_H \
|
||||||
|
} else if (num_parallel_tokens == 2) { \
|
||||||
|
static constexpr int NUM_PARALLEL_TOKENS = 2; \
|
||||||
|
KERNEL_CALL_H \
|
||||||
|
} else if (num_parallel_tokens == 4) { \
|
||||||
|
static constexpr int NUM_PARALLEL_TOKENS = 4; \
|
||||||
|
KERNEL_CALL_H \
|
||||||
|
} else if (num_parallel_tokens == 8) { \
|
||||||
|
static constexpr int NUM_PARALLEL_TOKENS = 8; \
|
||||||
|
KERNEL_CALL_H \
|
||||||
|
} else if (num_parallel_tokens == 16) { \
|
||||||
|
static constexpr int NUM_PARALLEL_TOKENS = 16; \
|
||||||
|
KERNEL_CALL_H \
|
||||||
|
} else if (num_parallel_tokens == 32) { \
|
||||||
|
static constexpr int NUM_PARALLEL_TOKENS = 32; \
|
||||||
|
KERNEL_CALL_H \
|
||||||
|
} else if (num_parallel_tokens == 64) { \
|
||||||
|
static constexpr int NUM_PARALLEL_TOKENS = 64; \
|
||||||
|
KERNEL_CALL_H \
|
||||||
|
}
|
||||||
|
|
||||||
|
Idx_t G;
|
||||||
|
dim3 block, grid;
|
||||||
|
auto populate_launch_params = [&](int num_warps, int _num_parallel_tokens) {
|
||||||
|
G = H / Idx_t(group_size * num_warps);
|
||||||
|
grid = dim3(E * G, _num_parallel_tokens);
|
||||||
|
block = dim3(num_warps * WARP_SIZE);
|
||||||
|
};
|
||||||
|
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||||
|
VLLM_DISPATCH_FP8_TYPES(y_q.scalar_type(),
|
||||||
|
"silu_mul_fp8_quant_deep_gemm_kernel",
|
||||||
|
[&] { KERNEL_CALL_TOP_LEVEL });
|
||||||
|
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|||||||
@ -5,7 +5,9 @@
|
|||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
|
#include "nvidia/quant_utils.cuh"
|
||||||
|
#else
|
||||||
#include "amd/quant_utils.cuh"
|
#include "amd/quant_utils.cuh"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -48,7 +50,9 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
|
|||||||
float r =
|
float r =
|
||||||
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
|
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
return static_cast<fp8_type>(r);
|
// Use hardware cvt instruction for fp8 on nvidia
|
||||||
|
// Currently only support fp8_type = c10::Float8_e4m3fn
|
||||||
|
return fp8::vec_conversion<fp8_type, float>(r);
|
||||||
#else
|
#else
|
||||||
// Use hardware cvt instruction for fp8 on rocm
|
// Use hardware cvt instruction for fp8 on rocm
|
||||||
return fp8::cvt_c10<fp8_type>(r);
|
return fp8::cvt_c10<fp8_type>(r);
|
||||||
|
|||||||
@ -12,13 +12,26 @@ namespace vllm {
|
|||||||
namespace fp8 {
|
namespace fp8 {
|
||||||
#ifdef ENABLE_FP8
|
#ifdef ENABLE_FP8
|
||||||
|
|
||||||
#if 0 // Disable the following code to reduce the binary size.
|
|
||||||
template <typename Tout, typename Tin>
|
template <typename Tout, typename Tin>
|
||||||
__inline__ __device__ Tout
|
__inline__ __device__ Tout vec_conversion(
|
||||||
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
|
const Tin& x, const __nv_fp8_interpretation_t fp8_type = __NV_E4M3) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// float -> c10::Float8_e4m3fn
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ c10::Float8_e4m3fn
|
||||||
|
vec_conversion<c10::Float8_e4m3fn, float>(
|
||||||
|
const float& a, const __nv_fp8_interpretation_t fp8_type) {
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
|
return static_cast<c10::Float8_e4m3fn>(a);
|
||||||
|
#else
|
||||||
|
return c10::Float8_e4m3fn(__nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type),
|
||||||
|
c10::Float8_e4m3fn::from_bits());
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
#if 0 // Disable the following code to reduce the binary size.
|
||||||
// fp8 -> half
|
// fp8 -> half
|
||||||
template <>
|
template <>
|
||||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
|
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
|
||||||
|
|||||||
@ -32,6 +32,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
#define stride_tag
|
#define stride_tag
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
ops.def(
|
||||||
|
"silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! "
|
||||||
|
"y_q, Tensor! y_s, int group_size, "
|
||||||
|
"bool use_ue8m0, int num_parallel_tokens) -> ()");
|
||||||
|
ops.impl("silu_mul_fp8_quant_deep_gemm_cuda", torch::kCUDA,
|
||||||
|
&silu_mul_fp8_quant_deep_gemm_cuda);
|
||||||
|
|
||||||
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
||||||
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
|
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
|
||||||
|
|
||||||
@ -214,16 +221,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
||||||
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
|
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
|
||||||
|
|
||||||
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
|
|
||||||
// (supports multiple loras).
|
|
||||||
ops.def(
|
|
||||||
"batched_rotary_embedding(Tensor positions, Tensor! query,"
|
|
||||||
" Tensor!? key, int head_size,"
|
|
||||||
" Tensor cos_sin_cache, bool is_neox,"
|
|
||||||
" int rot_dim,"
|
|
||||||
" Tensor cos_sin_cache_offsets) -> ()");
|
|
||||||
ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
|
|
||||||
|
|
||||||
// Quantization ops
|
// Quantization ops
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
// Quantized GEMM for AWQ.
|
// Quantized GEMM for AWQ.
|
||||||
|
|||||||
@ -56,7 +56,7 @@ vLLM is flexible and easy to use with:
|
|||||||
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
||||||
- Streaming outputs
|
- Streaming outputs
|
||||||
- OpenAI-compatible API server
|
- OpenAI-compatible API server
|
||||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators.
|
- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
|
||||||
- Prefix caching support
|
- Prefix caching support
|
||||||
- Multi-LoRA support
|
- Multi-LoRA support
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,6 @@
|
|||||||
|
|
||||||
vLLM's examples are split into three categories:
|
vLLM's examples are split into three categories:
|
||||||
|
|
||||||
- If you are using vLLM from within Python code, see [Offline Inference](./offline_inference)
|
- If you are using vLLM from within Python code, see the *Offline Inference* section.
|
||||||
- If you are using vLLM from an HTTP application or client, see [Online Serving](./online_serving)
|
- If you are using vLLM from an HTTP application or client, see the *Online Serving* section.
|
||||||
- For examples of using some of vLLM's advanced features (e.g. LMCache or Tensorizer) which are not specific to either of the above use cases, see [Others](./others)
|
- For examples of using some of vLLM's advanced features (e.g. LMCache or Tensorizer) which are not specific to either of the above use cases, see the *Others* section.
|
||||||
|
|||||||
@ -76,6 +76,3 @@ th:not(:first-child) {
|
|||||||
| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:8477) | ✅ | ❌ |
|
| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:8477) | ✅ | ❌ |
|
||||||
| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
|
||||||
!!! note
|
|
||||||
Please refer to [Feature support through NxD Inference backend][feature-support-through-nxd-inference-backend] for features supported on AWS Neuron hardware
|
|
||||||
|
|||||||
@ -45,6 +45,32 @@ When using multi-modal inputs, vLLM normally hashes each media item by content t
|
|||||||
print(o.outputs[0].text)
|
print(o.outputs[0].text)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Using UUIDs, you can also skip sending media data entirely if you expect cache hits for respective items. Note that the request will fail if the skipped media doesn't have a corresponding UUID, or if the UUID fails to hit the cache.
|
||||||
|
|
||||||
|
??? code
|
||||||
|
|
||||||
|
```python
|
||||||
|
from vllm import LLM
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# Qwen2.5-VL example with two images
|
||||||
|
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")
|
||||||
|
|
||||||
|
prompt = "USER: <image><image>\nDescribe the differences.\nASSISTANT:"
|
||||||
|
img_b = Image.open("/path/to/b.jpg")
|
||||||
|
|
||||||
|
outputs = llm.generate({
|
||||||
|
"prompt": prompt,
|
||||||
|
"multi_modal_data": {"image": [None, img_b]},
|
||||||
|
# Since img_a is expected to be cached, we can skip sending the actual
|
||||||
|
# image entirely.
|
||||||
|
"multi_modal_uuids": {"image": ["sku-1234-a", None]},
|
||||||
|
})
|
||||||
|
|
||||||
|
for o in outputs:
|
||||||
|
print(o.outputs[0].text)
|
||||||
|
```
|
||||||
|
|
||||||
!!! warning
|
!!! warning
|
||||||
If both multimodal processor caching and prefix caching are disabled, user-provided `multi_modal_uuids` are ignored.
|
If both multimodal processor caching and prefix caching are disabled, user-provided `multi_modal_uuids` are ignored.
|
||||||
|
|
||||||
@ -755,6 +781,39 @@ The following example demonstrates how to pass image embeddings to the OpenAI se
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
For Online Serving, you can also skip sending media if you expect cache hits with provided UUIDs. You can do so by sending media like this:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Image/video/audio URL:
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": None,
|
||||||
|
"uuid": image_uuid,
|
||||||
|
},
|
||||||
|
|
||||||
|
# image_embeds
|
||||||
|
{
|
||||||
|
"type": "image_embeds",
|
||||||
|
"image_embeds": None,
|
||||||
|
"uuid": image_uuid
|
||||||
|
},
|
||||||
|
|
||||||
|
# input_audio:
|
||||||
|
{
|
||||||
|
"type": "input_audio",
|
||||||
|
"input_audio": None,
|
||||||
|
"uuid": audio_uuid
|
||||||
|
},
|
||||||
|
|
||||||
|
# PIL Image:
|
||||||
|
{
|
||||||
|
"type": "image_pil",
|
||||||
|
"image_pil": None
|
||||||
|
"uuid": image_uuid
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
!!! note
|
!!! note
|
||||||
Only one message can contain `{"type": "image_embeds"}`.
|
Only one message can contain `{"type": "image_embeds"}`.
|
||||||
If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc.
|
If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc.
|
||||||
|
|||||||
@ -43,19 +43,19 @@ th:not(:first-child) {
|
|||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
|
|
||||||
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU |
|
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | Google TPU |
|
||||||
|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------|
|
|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|
|
||||||
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ |
|
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ |
|
||||||
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ |
|
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ |
|
||||||
| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ |
|
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ |
|
||||||
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ |
|
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ |
|
| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ |
|
||||||
|
|
||||||
- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
|
- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
|
||||||
- ✅︎ indicates that the quantization method is supported on the specified hardware.
|
- ✅︎ indicates that the quantization method is supported on the specified hardware.
|
||||||
|
|||||||
@ -3,5 +3,3 @@ nav:
|
|||||||
- gpu.md
|
- gpu.md
|
||||||
- cpu.md
|
- cpu.md
|
||||||
- google_tpu.md
|
- google_tpu.md
|
||||||
- intel_gaudi.md
|
|
||||||
- aws_neuron.md
|
|
||||||
|
|||||||
@ -12,7 +12,6 @@ vLLM supports the following hardware platforms:
|
|||||||
- [Apple silicon](cpu.md#apple-silicon)
|
- [Apple silicon](cpu.md#apple-silicon)
|
||||||
- [IBM Z (S390X)](cpu.md#ibm-z-s390x)
|
- [IBM Z (S390X)](cpu.md#ibm-z-s390x)
|
||||||
- [Google TPU](google_tpu.md)
|
- [Google TPU](google_tpu.md)
|
||||||
- [AWS Neuron](aws_neuron.md)
|
|
||||||
|
|
||||||
## Hardware Plugins
|
## Hardware Plugins
|
||||||
|
|
||||||
|
|||||||
@ -1,147 +0,0 @@
|
|||||||
# AWS Neuron
|
|
||||||
|
|
||||||
[AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/) is the software development kit (SDK) used to run deep learning and
|
|
||||||
generative AI workloads on AWS Inferentia and AWS Trainium powered Amazon EC2 instances and UltraServers (Inf1, Inf2, Trn1, Trn2,
|
|
||||||
and Trn2 UltraServer). Both Trainium and Inferentia are powered by fully-independent heterogeneous compute-units called NeuronCores.
|
|
||||||
This describes how to set up your environment to run vLLM on Neuron.
|
|
||||||
|
|
||||||
!!! warning
|
|
||||||
There are no pre-built wheels or images for this device, so you must build vLLM from source.
|
|
||||||
|
|
||||||
## Requirements
|
|
||||||
|
|
||||||
- OS: Linux
|
|
||||||
- Python: 3.9 or newer
|
|
||||||
- Pytorch 2.5/2.6
|
|
||||||
- Accelerator: NeuronCore-v2 (in trn1/inf2 chips) or NeuronCore-v3 (in trn2 chips)
|
|
||||||
- AWS Neuron SDK 2.23
|
|
||||||
|
|
||||||
## Configure a new environment
|
|
||||||
|
|
||||||
### Launch a Trn1/Trn2/Inf2 instance and verify Neuron dependencies
|
|
||||||
|
|
||||||
The easiest way to launch a Trainium or Inferentia instance with pre-installed Neuron dependencies is to follow this
|
|
||||||
[quick start guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/neuron-setup/multiframework/multi-framework-ubuntu22-neuron-dlami.html#setup-ubuntu22-multi-framework-dlami) using the Neuron Deep Learning AMI (Amazon machine image).
|
|
||||||
|
|
||||||
- After launching the instance, follow the instructions in [Connect to your instance](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AccessingInstancesLinux.html) to connect to the instance
|
|
||||||
- Once inside your instance, activate the pre-installed virtual environment for inference by running
|
|
||||||
|
|
||||||
```bash
|
|
||||||
source /opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/bin/activate
|
|
||||||
```
|
|
||||||
|
|
||||||
Refer to the [NxD Inference Setup Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/nxdi-setup.html)
|
|
||||||
for alternative setup instructions including using Docker and manually installing dependencies.
|
|
||||||
|
|
||||||
!!! note
|
|
||||||
NxD Inference is the default recommended backend to run inference on Neuron. If you are looking to use the legacy [transformers-neuronx](https://github.com/aws-neuron/transformers-neuronx)
|
|
||||||
library, refer to [Transformers NeuronX Setup](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/transformers-neuronx/setup/index.html).
|
|
||||||
|
|
||||||
## Set up using Python
|
|
||||||
|
|
||||||
### Pre-built wheels
|
|
||||||
|
|
||||||
Currently, there are no pre-built Neuron wheels.
|
|
||||||
|
|
||||||
### Build wheel from source
|
|
||||||
|
|
||||||
To build and install vLLM from source, run:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/vllm-project/vllm.git
|
|
||||||
cd vllm
|
|
||||||
pip install -U -r requirements/neuron.txt
|
|
||||||
VLLM_TARGET_DEVICE="neuron" pip install -e .
|
|
||||||
```
|
|
||||||
|
|
||||||
AWS Neuron maintains a [Github fork of vLLM](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2) at
|
|
||||||
<https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2>, which contains several features in addition to what's
|
|
||||||
available on vLLM V0. Please utilize the AWS Fork for the following features:
|
|
||||||
|
|
||||||
- Llama-3.2 multi-modal support
|
|
||||||
- Multi-node distributed inference
|
|
||||||
|
|
||||||
Refer to [vLLM User Guide for NxD Inference](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/vllm-user-guide.html)
|
|
||||||
for more details and usage examples.
|
|
||||||
|
|
||||||
To install the AWS Neuron fork, run the following:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone -b neuron-2.23-vllm-v0.7.2 https://github.com/aws-neuron/upstreaming-to-vllm.git
|
|
||||||
cd upstreaming-to-vllm
|
|
||||||
pip install -r requirements/neuron.txt
|
|
||||||
VLLM_TARGET_DEVICE="neuron" pip install -e .
|
|
||||||
```
|
|
||||||
|
|
||||||
Note that the AWS Neuron fork is only intended to support Neuron hardware; compatibility with other hardwares is not tested.
|
|
||||||
|
|
||||||
## Set up using Docker
|
|
||||||
|
|
||||||
### Pre-built images
|
|
||||||
|
|
||||||
Currently, there are no pre-built Neuron images.
|
|
||||||
|
|
||||||
### Build image from source
|
|
||||||
|
|
||||||
See [deployment-docker-build-image-from-source][deployment-docker-build-image-from-source] for instructions on building the Docker image.
|
|
||||||
|
|
||||||
Make sure to use <gh-file:docker/Dockerfile.neuron> in place of the default Dockerfile.
|
|
||||||
|
|
||||||
## Extra information
|
|
||||||
|
|
||||||
[](){ #feature-support-through-nxd-inference-backend }
|
|
||||||
|
|
||||||
### Feature support through NxD Inference backend
|
|
||||||
|
|
||||||
The current vLLM and Neuron integration relies on either the `neuronx-distributed-inference` (preferred) or `transformers-neuronx` backend
|
|
||||||
to perform most of the heavy lifting which includes PyTorch model initialization, compilation, and runtime execution. Therefore, most
|
|
||||||
[features supported on Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html) are also available via the vLLM integration.
|
|
||||||
|
|
||||||
To configure NxD Inference features through the vLLM entrypoint, use the `override_neuron_config` setting. Provide the configs you want to override
|
|
||||||
as a dictionary (or JSON object when starting vLLM from the CLI). For example, to disable auto bucketing, include
|
|
||||||
|
|
||||||
```python
|
|
||||||
override_neuron_config={
|
|
||||||
"enable_bucketing":False,
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
or when launching vLLM from the CLI, pass
|
|
||||||
|
|
||||||
```bash
|
|
||||||
--override-neuron-config "{\"enable_bucketing\":false}"
|
|
||||||
```
|
|
||||||
|
|
||||||
Alternatively, users can directly call the NxDI library to trace and compile your model, then load the pre-compiled artifacts
|
|
||||||
(via `NEURON_COMPILED_ARTIFACTS` environment variable) in vLLM to run inference workloads.
|
|
||||||
|
|
||||||
### Known limitations
|
|
||||||
|
|
||||||
- EAGLE speculative decoding: NxD Inference requires the EAGLE draft checkpoint to include the LM head weights from the target model. Refer to this
|
|
||||||
[guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html#eagle-checkpoint-compatibility)
|
|
||||||
for how to convert pretrained EAGLE model checkpoints to be compatible for NxDI.
|
|
||||||
- Quantization: the native quantization flow in vLLM is not well supported on NxD Inference. It is recommended to follow this
|
|
||||||
[Neuron quantization guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/custom-quantization.html)
|
|
||||||
to quantize and compile your model using NxD Inference, and then load the compiled artifacts into vLLM.
|
|
||||||
- Multi-LoRA serving: NxD Inference only supports loading of LoRA adapters at server startup. Dynamic loading of LoRA adapters at
|
|
||||||
runtime is not currently supported. Refer to [multi-lora example](https://github.com/aws-neuron/upstreaming-to-vllm/blob/neuron-2.23-vllm-v0.7.2/examples/offline_inference/neuron_multi_lora.py)
|
|
||||||
- Multi-modal support: multi-modal support is only available through the AWS Neuron fork. This feature has not been upstreamed
|
|
||||||
to vLLM main because NxD Inference currently relies on certain adaptations to the core vLLM logic to support this feature.
|
|
||||||
- Multi-node support: distributed inference across multiple Trainium/Inferentia instances is only supported on the AWS Neuron fork. Refer
|
|
||||||
to this [multi-node example](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2/examples/neuron/multi_node)
|
|
||||||
to run. Note that tensor parallelism (distributed inference across NeuronCores) is available in vLLM main.
|
|
||||||
- Known edge case bug in speculative decoding: An edge case failure may occur in speculative decoding when sequence length approaches
|
|
||||||
max model length (e.g. when requesting max tokens up to the max model length and ignoring eos). In this scenario, vLLM may attempt
|
|
||||||
to allocate an additional block to ensure there is enough memory for number of lookahead slots, but since we do not have good support
|
|
||||||
for paged attention, there isn't another Neuron block for vLLM to allocate. A workaround fix (to terminate 1 iteration early) is
|
|
||||||
implemented in the AWS Neuron fork but is not upstreamed to vLLM main as it modifies core vLLM logic.
|
|
||||||
|
|
||||||
### Environment variables
|
|
||||||
|
|
||||||
- `NEURON_COMPILED_ARTIFACTS`: set this environment variable to point to your pre-compiled model artifacts directory to avoid
|
|
||||||
compilation time upon server initialization. If this variable is not set, the Neuron module will perform compilation and save the
|
|
||||||
artifacts under `neuron-compiled-artifacts/{unique_hash}/` subdirectory in the model path. If this environment variable is set,
|
|
||||||
but the directory does not exist, or the contents are invalid, Neuron will also fall back to a new compilation and store the artifacts
|
|
||||||
under this specified path.
|
|
||||||
- `NEURON_CONTEXT_LENGTH_BUCKETS`: Bucket sizes for context encoding. (Only applicable to `transformers-neuronx` backend).
|
|
||||||
- `NEURON_TOKEN_GEN_BUCKETS`: Bucket sizes for token generation. (Only applicable to `transformers-neuronx` backend).
|
|
||||||
@ -389,6 +389,7 @@ th {
|
|||||||
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
|
| `OLMo3ForCausalLM` | OLMo3 | TBA | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ |
|
| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ |
|
||||||
| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ |
|
| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ |
|
||||||
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
|
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
|
||||||
|
|||||||
@ -42,7 +42,7 @@ def main():
|
|||||||
llm_args["model"] = "meta-llama/Llama-3.1-8B-Instruct"
|
llm_args["model"] = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
|
||||||
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
|
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
|
||||||
# In real workloads, `enforace_eager` should be `False`.
|
# In real workloads, `enforce_eager` should be `False`.
|
||||||
llm = LLM(**llm_args)
|
llm = LLM(**llm_args)
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
print("-" * 50)
|
print("-" * 50)
|
||||||
|
|||||||
@ -1764,6 +1764,7 @@ def apply_image_repeat(
|
|||||||
probs = [1.0 - image_repeat_prob, image_repeat_prob]
|
probs = [1.0 - image_repeat_prob, image_repeat_prob]
|
||||||
|
|
||||||
inputs = []
|
inputs = []
|
||||||
|
inputs_with_empty_media = []
|
||||||
cur_image = data
|
cur_image = data
|
||||||
for i in range(num_prompts):
|
for i in range(num_prompts):
|
||||||
if image_repeat_prob is not None:
|
if image_repeat_prob is not None:
|
||||||
@ -1774,14 +1775,25 @@ def apply_image_repeat(
|
|||||||
new_val = (i // 256 // 256, i // 256, i % 256)
|
new_val = (i // 256 // 256, i // 256, i % 256)
|
||||||
cur_image.putpixel((0, 0), new_val)
|
cur_image.putpixel((0, 0), new_val)
|
||||||
|
|
||||||
|
uuid = "uuid_{}".format(i)
|
||||||
|
|
||||||
inputs.append(
|
inputs.append(
|
||||||
{
|
{
|
||||||
"prompt": prompts[i % len(prompts)],
|
"prompt": prompts[i % len(prompts)],
|
||||||
"multi_modal_data": {modality: cur_image},
|
"multi_modal_data": {modality: cur_image},
|
||||||
|
"multi_modal_uuids": {modality: uuid},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return inputs
|
inputs_with_empty_media.append(
|
||||||
|
{
|
||||||
|
"prompt": prompts[i % len(prompts)],
|
||||||
|
"multi_modal_data": {modality: None},
|
||||||
|
"multi_modal_uuids": {modality: uuid},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return inputs, inputs_with_empty_media
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@ -1860,6 +1872,13 @@ def parse_args():
|
|||||||
help="If True, then use different prompt (with the same multi-modal "
|
help="If True, then use different prompt (with the same multi-modal "
|
||||||
"data) for each request.",
|
"data) for each request.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--verify-mm-cache-hit-with-uuids",
|
||||||
|
action="store_true",
|
||||||
|
help="If True, will send all requests in a second batch with empty mm "
|
||||||
|
"data to verify cache hits with UUIDs.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -1903,26 +1922,48 @@ def main(args):
|
|||||||
assert args.num_prompts > 0
|
assert args.num_prompts > 0
|
||||||
if args.num_prompts == 1:
|
if args.num_prompts == 1:
|
||||||
# Single inference
|
# Single inference
|
||||||
|
uuid = "uuid_0"
|
||||||
inputs = {
|
inputs = {
|
||||||
"prompt": prompts[0],
|
"prompt": prompts[0],
|
||||||
"multi_modal_data": {modality: data},
|
"multi_modal_data": {modality: data},
|
||||||
|
"multi_modal_uuids": {modality: uuid},
|
||||||
|
}
|
||||||
|
inputs_with_empty_media = {
|
||||||
|
"prompt": prompts[0],
|
||||||
|
"multi_modal_data": {modality: None},
|
||||||
|
"multi_modal_uuids": {modality: uuid},
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# Batch inference
|
# Batch inference
|
||||||
if args.image_repeat_prob is not None:
|
if args.image_repeat_prob is not None:
|
||||||
# Repeat images with specified probability of "image_repeat_prob"
|
# Repeat images with specified probability of "image_repeat_prob"
|
||||||
inputs = apply_image_repeat(
|
inputs, inputs_with_empty_media = apply_image_repeat(
|
||||||
args.image_repeat_prob, args.num_prompts, data, prompts, modality
|
args.image_repeat_prob,
|
||||||
|
args.num_prompts,
|
||||||
|
data,
|
||||||
|
prompts,
|
||||||
|
modality,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use the same image for all prompts
|
# Use the same image for all prompts
|
||||||
inputs = [
|
inputs = []
|
||||||
{
|
inputs_with_empty_media = []
|
||||||
"prompt": prompts[i % len(prompts)],
|
for i in range(args.num_prompts):
|
||||||
"multi_modal_data": {modality: data},
|
uuid = "uuid_{}".format(i)
|
||||||
}
|
inputs.append(
|
||||||
for i in range(args.num_prompts)
|
{
|
||||||
]
|
"prompt": prompts[i % len(prompts)],
|
||||||
|
"multi_modal_data": {modality: data},
|
||||||
|
"multi_modal_uuids": {modality: uuid},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
inputs_with_empty_media.append(
|
||||||
|
{
|
||||||
|
"prompt": prompts[i % len(prompts)],
|
||||||
|
"multi_modal_data": {modality: None},
|
||||||
|
"multi_modal_uuids": {modality: uuid},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Add LoRA request if applicable
|
# Add LoRA request if applicable
|
||||||
lora_request = (
|
lora_request = (
|
||||||
@ -1942,6 +1983,26 @@ def main(args):
|
|||||||
print(generated_text)
|
print(generated_text)
|
||||||
print("-" * 50)
|
print("-" * 50)
|
||||||
|
|
||||||
|
if args.verify_mm_cache_hit_with_uuids:
|
||||||
|
try:
|
||||||
|
# Verify cache hits with UUIDs
|
||||||
|
print(
|
||||||
|
"Sending a second batch of requests with empty media"
|
||||||
|
" and matching UUIDs."
|
||||||
|
)
|
||||||
|
outputs = llm.generate(
|
||||||
|
inputs_with_empty_media,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
lora_request=lora_request,
|
||||||
|
)
|
||||||
|
print("-" * 50)
|
||||||
|
for o in outputs:
|
||||||
|
generated_text = o.outputs[0].text
|
||||||
|
print(generated_text)
|
||||||
|
print("-" * 50)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to verify cache hits with UUIDs. Error: {e}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|||||||
@ -62,6 +62,8 @@ def _fix_prompt_embed_outputs(
|
|||||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
|
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
|
||||||
@pytest.mark.parametrize("max_tokens", [5])
|
@pytest.mark.parametrize("max_tokens", [5])
|
||||||
@pytest.mark.parametrize("enforce_eager", [False])
|
@pytest.mark.parametrize("enforce_eager", [False])
|
||||||
|
@pytest.mark.parametrize("async_scheduling", [True, False])
|
||||||
|
@pytest.mark.parametrize("model_executor", ["uni", "mp"])
|
||||||
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
||||||
def test_models(
|
def test_models(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
@ -70,6 +72,8 @@ def test_models(
|
|||||||
backend: str,
|
backend: str,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
enforce_eager: bool,
|
enforce_eager: bool,
|
||||||
|
async_scheduling: bool,
|
||||||
|
model_executor: str,
|
||||||
enable_prompt_embeds: bool,
|
enable_prompt_embeds: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
@ -77,6 +81,12 @@ def test_models(
|
|||||||
"VLLM_USE_V1") and envs.VLLM_USE_V1:
|
"VLLM_USE_V1") and envs.VLLM_USE_V1:
|
||||||
pytest.skip("enable_prompt_embeds is not supported in v1.")
|
pytest.skip("enable_prompt_embeds is not supported in v1.")
|
||||||
|
|
||||||
|
if not envs.VLLM_USE_V1:
|
||||||
|
if async_scheduling:
|
||||||
|
pytest.skip("async_scheduling only supported in v1.")
|
||||||
|
if model_executor != "uni":
|
||||||
|
pytest.skip("only test uniproc executor for v0.")
|
||||||
|
|
||||||
if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
|
if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
f"{backend} does not support gemma2 with full context length.")
|
f"{backend} does not support gemma2 with full context length.")
|
||||||
@ -98,11 +108,15 @@ def test_models(
|
|||||||
prompt_embeds = hf_model.get_prompt_embeddings(
|
prompt_embeds = hf_model.get_prompt_embeddings(
|
||||||
example_prompts)
|
example_prompts)
|
||||||
|
|
||||||
with VllmRunner(model,
|
with VllmRunner(
|
||||||
max_model_len=8192,
|
model,
|
||||||
enforce_eager=enforce_eager,
|
max_model_len=8192,
|
||||||
enable_prompt_embeds=enable_prompt_embeds,
|
enforce_eager=enforce_eager,
|
||||||
gpu_memory_utilization=0.7) as vllm_model:
|
enable_prompt_embeds=enable_prompt_embeds,
|
||||||
|
gpu_memory_utilization=0.7,
|
||||||
|
async_scheduling=async_scheduling,
|
||||||
|
distributed_executor_backend=model_executor,
|
||||||
|
) as vllm_model:
|
||||||
if enable_prompt_embeds:
|
if enable_prompt_embeds:
|
||||||
vllm_outputs = vllm_model.generate_greedy(
|
vllm_outputs = vllm_model.generate_greedy(
|
||||||
prompt_embeds, max_tokens)
|
prompt_embeds, max_tokens)
|
||||||
|
|||||||
@ -522,6 +522,71 @@ async def test_completions_with_image_with_uuid(
|
|||||||
assert isinstance(chat_completion.choices[0].message.content, str)
|
assert isinstance(chat_completion.choices[0].message.content, str)
|
||||||
assert len(chat_completion.choices[0].message.content) > 0
|
assert len(chat_completion.choices[0].message.content) > 0
|
||||||
|
|
||||||
|
# Second request, with empty image but the same uuid.
|
||||||
|
chat_completion_with_empty_image = await client.chat.completions.create(
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Describe this image.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {},
|
||||||
|
"uuid": image_url
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
model=model_name,
|
||||||
|
)
|
||||||
|
assert chat_completion_with_empty_image.choices[
|
||||||
|
0].message.content is not None
|
||||||
|
assert isinstance(
|
||||||
|
chat_completion_with_empty_image.choices[0].message.content, str)
|
||||||
|
assert len(
|
||||||
|
chat_completion_with_empty_image.choices[0].message.content) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
|
async def test_completions_with_empty_image_with_uuid_without_cache_hit(
|
||||||
|
client: openai.AsyncOpenAI,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
with pytest.raises(openai.BadRequestError):
|
||||||
|
_ = await client.chat.completions.create(
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Describe this image.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {},
|
||||||
|
"uuid": "uuid_not_previously_seen"
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
model=model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
|
|||||||
@ -79,6 +79,28 @@ def phi3v_tokenizer():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def qwen2_audio_model_config():
|
||||||
|
return ModelConfig(
|
||||||
|
QWEN2AUDIO_MODEL_ID,
|
||||||
|
runner="generate",
|
||||||
|
trust_remote_code=True,
|
||||||
|
limit_mm_per_prompt={
|
||||||
|
"audio": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def qwen2_audio_tokenizer():
|
||||||
|
return TokenizerGroup(
|
||||||
|
tokenizer_id=QWEN2AUDIO_MODEL_ID,
|
||||||
|
enable_lora=False,
|
||||||
|
max_num_seqs=5,
|
||||||
|
max_input_length=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def qwen25omni_model_config_mm_interleaved():
|
def qwen25omni_model_config_mm_interleaved():
|
||||||
return ModelConfig(
|
return ModelConfig(
|
||||||
@ -169,6 +191,7 @@ def audio_url():
|
|||||||
def _assert_mm_data_is_image_input(
|
def _assert_mm_data_is_image_input(
|
||||||
mm_data: Optional[MultiModalDataDict],
|
mm_data: Optional[MultiModalDataDict],
|
||||||
image_count: int,
|
image_count: int,
|
||||||
|
skipped_image_indices: Optional[list] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert mm_data is not None
|
assert mm_data is not None
|
||||||
assert set(mm_data.keys()) == {"image"}
|
assert set(mm_data.keys()) == {"image"}
|
||||||
@ -177,6 +200,9 @@ def _assert_mm_data_is_image_input(
|
|||||||
assert image_data is not None
|
assert image_data is not None
|
||||||
|
|
||||||
assert isinstance(image_data, list) and len(image_data) == image_count
|
assert isinstance(image_data, list) and len(image_data) == image_count
|
||||||
|
if skipped_image_indices is not None:
|
||||||
|
for i in skipped_image_indices:
|
||||||
|
assert image_data[i] is None
|
||||||
|
|
||||||
|
|
||||||
def _assert_mm_uuids(
|
def _assert_mm_uuids(
|
||||||
@ -205,8 +231,10 @@ MultiModalDataCounts = Mapping[ModalityType, int]
|
|||||||
|
|
||||||
|
|
||||||
def _assert_mm_data_inputs(
|
def _assert_mm_data_inputs(
|
||||||
mm_data: Optional[MultiModalDataDict],
|
mm_data: Optional[MultiModalDataDict],
|
||||||
data_count: MultiModalDataCounts,
|
data_count: MultiModalDataCounts,
|
||||||
|
skipped_media_indices: Optional[dict[
|
||||||
|
str, list]] = None, # modality -> list[int]
|
||||||
) -> None:
|
) -> None:
|
||||||
assert mm_data is not None
|
assert mm_data is not None
|
||||||
assert set(data_count.keys()) == (set(mm_data.keys()))
|
assert set(data_count.keys()) == (set(mm_data.keys()))
|
||||||
@ -216,6 +244,13 @@ def _assert_mm_data_inputs(
|
|||||||
assert modality_data is not None
|
assert modality_data is not None
|
||||||
assert isinstance(modality_data, list) and len(modality_data) == n
|
assert isinstance(modality_data, list) and len(modality_data) == n
|
||||||
|
|
||||||
|
if skipped_media_indices is not None:
|
||||||
|
skipped_media_indices_for_modality = skipped_media_indices.get(
|
||||||
|
modality)
|
||||||
|
assert skipped_media_indices_for_modality is not None
|
||||||
|
for i in skipped_media_indices_for_modality:
|
||||||
|
assert modality_data[i] is None
|
||||||
|
|
||||||
|
|
||||||
def test_parse_chat_messages_single_image(
|
def test_parse_chat_messages_single_image(
|
||||||
phi3v_model_config,
|
phi3v_model_config,
|
||||||
@ -289,6 +324,41 @@ def test_parse_chat_messages_single_image_with_uuid(
|
|||||||
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid])
|
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid])
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_chat_messages_single_empty_image_with_uuid(
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
image_url,
|
||||||
|
):
|
||||||
|
image_uuid = str(hash(image_url))
|
||||||
|
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||||
|
[{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": None,
|
||||||
|
"uuid": image_uuid,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's in the image?"
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}],
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
content_format="string",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert conversation == [{
|
||||||
|
"role": "user",
|
||||||
|
"content": "<|image_1|>\nWhat's in the image?"
|
||||||
|
}]
|
||||||
|
_assert_mm_data_is_image_input(mm_data, 1, skipped_image_indices=[0])
|
||||||
|
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid])
|
||||||
|
|
||||||
|
|
||||||
def test_parse_chat_messages_single_image_with_bad_uuid_format(
|
def test_parse_chat_messages_single_image_with_bad_uuid_format(
|
||||||
phi3v_model_config,
|
phi3v_model_config,
|
||||||
phi3v_tokenizer,
|
phi3v_tokenizer,
|
||||||
@ -375,6 +445,96 @@ def test_parse_chat_messages_multiple_images_with_uuids(
|
|||||||
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2])
|
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2])
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_chat_messages_multiple_empty_images_with_uuids(
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
image_url,
|
||||||
|
):
|
||||||
|
image_uuid1 = "my_uuid_1"
|
||||||
|
image_uuid2 = "my_uuid_2"
|
||||||
|
|
||||||
|
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||||
|
[{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": None,
|
||||||
|
"uuid": image_uuid1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": None,
|
||||||
|
"uuid": image_uuid2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's in the image?"
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}],
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
content_format="string",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert conversation == [{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content":
|
||||||
|
"<|image_1|>\n<|image_2|>\nWhat's in the image?",
|
||||||
|
}]
|
||||||
|
_assert_mm_data_is_image_input(mm_data, 2, skipped_image_indices=[0, 1])
|
||||||
|
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2])
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_chat_messages_mixed_empty_images_with_uuids(
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
image_url,
|
||||||
|
):
|
||||||
|
image_uuid1 = "my_uuid_1"
|
||||||
|
image_uuid2 = "my_uuid_2"
|
||||||
|
|
||||||
|
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||||
|
[{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": image_url,
|
||||||
|
},
|
||||||
|
"uuid": image_uuid1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": None,
|
||||||
|
"uuid": image_uuid2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's in the image?"
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}],
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
content_format="string",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert conversation == [{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content":
|
||||||
|
"<|image_1|>\n<|image_2|>\nWhat's in the image?",
|
||||||
|
}]
|
||||||
|
_assert_mm_data_is_image_input(mm_data, 2, skipped_image_indices=[1])
|
||||||
|
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_parse_chat_messages_single_image_with_uuid_async(
|
async def test_parse_chat_messages_single_image_with_uuid_async(
|
||||||
phi3v_model_config,
|
phi3v_model_config,
|
||||||
@ -413,6 +573,44 @@ async def test_parse_chat_messages_single_image_with_uuid_async(
|
|||||||
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid])
|
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_chat_messages_empty_image_with_uuid_async(
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
image_url,
|
||||||
|
):
|
||||||
|
image_uuid = str(hash(image_url))
|
||||||
|
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
|
||||||
|
[{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": None,
|
||||||
|
"uuid": image_uuid,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's in the image?"
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}],
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
content_format="string",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert conversation == [{
|
||||||
|
"role": "user",
|
||||||
|
"content": "<|image_1|>\nWhat's in the image?"
|
||||||
|
}]
|
||||||
|
_assert_mm_data_is_image_input(await mm_future,
|
||||||
|
1,
|
||||||
|
skipped_image_indices=[0])
|
||||||
|
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_parse_chat_messages_multiple_images_with_uuids_async(
|
async def test_parse_chat_messages_multiple_images_with_uuids_async(
|
||||||
phi3v_model_config,
|
phi3v_model_config,
|
||||||
@ -460,6 +658,53 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async(
|
|||||||
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2])
|
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
image_url,
|
||||||
|
):
|
||||||
|
image_uuid1 = "my_uuid_1"
|
||||||
|
image_uuid2 = "my_uuid_2"
|
||||||
|
|
||||||
|
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
|
||||||
|
[{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": None,
|
||||||
|
"uuid": image_uuid1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_pil",
|
||||||
|
"image_pil": None,
|
||||||
|
"uuid": image_uuid2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's in these images?"
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}],
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
content_format="string",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert conversation == [{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content":
|
||||||
|
"<|image_1|>\n<|image_2|>\nWhat's in these images?",
|
||||||
|
}]
|
||||||
|
_assert_mm_data_is_image_input(await mm_future,
|
||||||
|
2,
|
||||||
|
skipped_image_indices=[0, 1])
|
||||||
|
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
|
async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
|
||||||
phi3v_model_config,
|
phi3v_model_config,
|
||||||
@ -653,6 +898,114 @@ def test_parse_chat_messages_multiple_images(
|
|||||||
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
|
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_chat_messages_empty_pil_image_with_uuid(
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
):
|
||||||
|
uuid = "abcd"
|
||||||
|
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||||
|
[{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_pil",
|
||||||
|
"image_pil": None,
|
||||||
|
"uuid": uuid
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's in this image?"
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}],
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
content_format="string",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert conversation == [{
|
||||||
|
"role": "user",
|
||||||
|
"content": "<|image_1|>\nWhat's in this image?",
|
||||||
|
}]
|
||||||
|
_assert_mm_data_is_image_input(mm_data, 1, skipped_image_indices=[0])
|
||||||
|
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_chat_messages_empty_image_embeds_with_uuid(
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
):
|
||||||
|
uuid = "abcd"
|
||||||
|
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||||
|
[{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_embeds",
|
||||||
|
"image_embeds": None,
|
||||||
|
"uuid": uuid
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's in this image?"
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}],
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
content_format="string",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert conversation == [{
|
||||||
|
"role": "user",
|
||||||
|
"content": "<|image_1|>\nWhat's in this image?",
|
||||||
|
}]
|
||||||
|
assert mm_data is not None
|
||||||
|
assert "image" in mm_data
|
||||||
|
assert mm_data["image"] is None
|
||||||
|
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
):
|
||||||
|
uuid = "abcd"
|
||||||
|
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
|
||||||
|
[{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_embeds",
|
||||||
|
"image_embeds": None,
|
||||||
|
"uuid": uuid
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's in this image?"
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}],
|
||||||
|
phi3v_model_config,
|
||||||
|
phi3v_tokenizer,
|
||||||
|
content_format="string",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert conversation == [{
|
||||||
|
"role": "user",
|
||||||
|
"content": "<|image_1|>\nWhat's in this image?",
|
||||||
|
}]
|
||||||
|
mm_data = await mm_future
|
||||||
|
assert mm_data is not None
|
||||||
|
assert "image" in mm_data
|
||||||
|
assert mm_data["image"] is None
|
||||||
|
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_parse_chat_messages_multiple_images_async(
|
async def test_parse_chat_messages_multiple_images_async(
|
||||||
phi3v_model_config,
|
phi3v_model_config,
|
||||||
@ -1636,6 +1989,118 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl
|
|||||||
expected_uuids=["audio_123"])
|
expected_uuids=["audio_123"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_messages_interleave( # noqa: E501
|
||||||
|
qwen25omni_model_config_mm_interleaved,
|
||||||
|
qwen25omni_tokenizer,
|
||||||
|
image_url,
|
||||||
|
video_url,
|
||||||
|
audio_url,
|
||||||
|
):
|
||||||
|
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's on this image?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": None,
|
||||||
|
"uuid": "image_123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Now listen to this audio"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "audio_url",
|
||||||
|
"audio_url": None,
|
||||||
|
"uuid": "audio_123",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Some stuff."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's on this image?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": None,
|
||||||
|
"uuid": "image_123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "And what's in the video?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "video_url",
|
||||||
|
"video_url": None,
|
||||||
|
"uuid": "video_123",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
qwen25omni_model_config_mm_interleaved,
|
||||||
|
qwen25omni_tokenizer,
|
||||||
|
content_format="string",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert conversation == [
|
||||||
|
{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content":
|
||||||
|
"What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n"
|
||||||
|
"Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Some stuff."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content":
|
||||||
|
"What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n"
|
||||||
|
"And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
_assert_mm_data_inputs(mm_data, {
|
||||||
|
"image": 2,
|
||||||
|
"video": 1,
|
||||||
|
"audio": 1
|
||||||
|
},
|
||||||
|
skipped_media_indices={
|
||||||
|
"image": [0, 1],
|
||||||
|
"video": [0],
|
||||||
|
"audio": [0]
|
||||||
|
})
|
||||||
|
_assert_mm_uuids(mm_uuids,
|
||||||
|
2,
|
||||||
|
modality="image",
|
||||||
|
expected_uuids=["image_123", "image_123"])
|
||||||
|
_assert_mm_uuids(mm_uuids,
|
||||||
|
1,
|
||||||
|
modality="video",
|
||||||
|
expected_uuids=["video_123"])
|
||||||
|
_assert_mm_uuids(mm_uuids,
|
||||||
|
1,
|
||||||
|
modality="audio",
|
||||||
|
expected_uuids=["audio_123"])
|
||||||
|
|
||||||
|
|
||||||
def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave( # noqa: E501
|
def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave( # noqa: E501
|
||||||
qwen25omni_model_config_mm_interleaved,
|
qwen25omni_model_config_mm_interleaved,
|
||||||
qwen25omni_tokenizer,
|
qwen25omni_tokenizer,
|
||||||
@ -2355,3 +2820,82 @@ def test_apply_mistral_chat_template_thinking_chunk():
|
|||||||
r"[INST]Thanks, what is 3+3?[/INST]")
|
r"[INST]Thanks, what is 3+3?[/INST]")
|
||||||
|
|
||||||
assert string_tokens == expected_tokens
|
assert string_tokens == expected_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_chat_messages_single_empty_audio_with_uuid(
|
||||||
|
qwen2_audio_model_config,
|
||||||
|
qwen2_audio_tokenizer,
|
||||||
|
):
|
||||||
|
audio_uuid = "abcd"
|
||||||
|
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||||
|
[{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "input_audio",
|
||||||
|
"input_audio": {},
|
||||||
|
"uuid": audio_uuid,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What does the audio say?"
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}],
|
||||||
|
qwen2_audio_model_config,
|
||||||
|
qwen2_audio_tokenizer,
|
||||||
|
content_format="string",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert conversation == [{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content":
|
||||||
|
"Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat does the audio say?"
|
||||||
|
}]
|
||||||
|
_assert_mm_data_inputs(mm_data, {"audio": 1})
|
||||||
|
_assert_mm_uuids(mm_uuids,
|
||||||
|
1,
|
||||||
|
modality="audio",
|
||||||
|
expected_uuids=[audio_uuid])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_chat_messages_single_empty_audio_with_uuid_async(
|
||||||
|
qwen2_audio_model_config,
|
||||||
|
qwen2_audio_tokenizer,
|
||||||
|
):
|
||||||
|
audio_uuid = "abcd"
|
||||||
|
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
|
||||||
|
[{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "input_audio",
|
||||||
|
"input_audio": {},
|
||||||
|
"uuid": audio_uuid,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What does the audio say?"
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}],
|
||||||
|
qwen2_audio_model_config,
|
||||||
|
qwen2_audio_tokenizer,
|
||||||
|
content_format="string",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert conversation == [{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content":
|
||||||
|
"Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat does the audio say?"
|
||||||
|
}]
|
||||||
|
_assert_mm_data_inputs(await mm_future, {"audio": 1})
|
||||||
|
_assert_mm_uuids(mm_uuids,
|
||||||
|
1,
|
||||||
|
modality="audio",
|
||||||
|
expected_uuids=[audio_uuid])
|
||||||
|
|||||||
@ -22,7 +22,10 @@ def clear_cache():
|
|||||||
|
|
||||||
# Define MLA and non-MLA backends separately
|
# Define MLA and non-MLA backends separately
|
||||||
DEVICE_MLA_BACKENDS = {
|
DEVICE_MLA_BACKENDS = {
|
||||||
"cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
|
"cuda": [
|
||||||
|
"TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA",
|
||||||
|
"CUTLASS_MLA"
|
||||||
|
],
|
||||||
"hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
|
"hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
|
||||||
"cpu": [],
|
"cpu": [],
|
||||||
}
|
}
|
||||||
@ -90,8 +93,8 @@ def test_env(
|
|||||||
|
|
||||||
with patch("vllm.attention.selector.current_platform",
|
with patch("vllm.attention.selector.current_platform",
|
||||||
CpuPlatform()):
|
CpuPlatform()):
|
||||||
backend = get_attn_backend(16, torch.float16, torch.float16,
|
backend = get_attn_backend(16, torch.float16, None, block_size,
|
||||||
block_size, False)
|
False)
|
||||||
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
||||||
|
|
||||||
elif device == "hip":
|
elif device == "hip":
|
||||||
@ -109,7 +112,7 @@ def test_env(
|
|||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
get_attn_backend(16,
|
get_attn_backend(16,
|
||||||
torch.float16,
|
torch.float16,
|
||||||
torch.float16,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
False,
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
@ -120,7 +123,7 @@ def test_env(
|
|||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
get_attn_backend(16,
|
get_attn_backend(16,
|
||||||
torch.float16,
|
torch.float16,
|
||||||
torch.float16,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
False,
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
@ -130,7 +133,7 @@ def test_env(
|
|||||||
# Valid backend-block_size combination
|
# Valid backend-block_size combination
|
||||||
backend = get_attn_backend(16,
|
backend = get_attn_backend(16,
|
||||||
torch.float16,
|
torch.float16,
|
||||||
torch.float16,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
False,
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
@ -139,7 +142,7 @@ def test_env(
|
|||||||
else:
|
else:
|
||||||
backend = get_attn_backend(16,
|
backend = get_attn_backend(16,
|
||||||
torch.float16,
|
torch.float16,
|
||||||
torch.float16,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
False,
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
@ -153,6 +156,8 @@ def test_env(
|
|||||||
# CUDA MLA backend logic:
|
# CUDA MLA backend logic:
|
||||||
# - CUTLASS_MLA: only supported with block_size == 128
|
# - CUTLASS_MLA: only supported with block_size == 128
|
||||||
# and Blackwell GPUs (SM 10.0), V1 only
|
# and Blackwell GPUs (SM 10.0), V1 only
|
||||||
|
# - FLASHINFER_MLA: only supported on Blackwell GPUs
|
||||||
|
# (SM 10.0+), V1 only
|
||||||
# - FLASHMLA: only supported with block_size == 64
|
# - FLASHMLA: only supported with block_size == 64
|
||||||
# - FLASH_ATTN_MLA: V1 only
|
# - FLASH_ATTN_MLA: V1 only
|
||||||
# - TRITON_MLA: fallback for other cases
|
# - TRITON_MLA: fallback for other cases
|
||||||
@ -169,12 +174,31 @@ def test_env(
|
|||||||
else:
|
else:
|
||||||
backend = get_attn_backend(16,
|
backend = get_attn_backend(16,
|
||||||
torch.float16,
|
torch.float16,
|
||||||
torch.float16,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
False,
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
expected = "CUTLASS_MLA_VLLM_V1"
|
expected = "CUTLASS_MLA_VLLM_V1"
|
||||||
assert backend.get_name() == expected
|
assert backend.get_name() == expected
|
||||||
|
elif name == "FLASHINFER_MLA":
|
||||||
|
if not use_v1:
|
||||||
|
# FlashInfer MLA only supported on V1 engine
|
||||||
|
pytest.skip(
|
||||||
|
"FlashInfer MLA only supported on V1 engine")
|
||||||
|
elif block_size not in [32, 64]:
|
||||||
|
# FlashInfer MLA only supports block_size 32 or 64
|
||||||
|
pytest.skip(
|
||||||
|
"FlashInfer MLA only supports block_size 32 "
|
||||||
|
"or 64")
|
||||||
|
else:
|
||||||
|
backend = get_attn_backend(16,
|
||||||
|
torch.float16,
|
||||||
|
None,
|
||||||
|
block_size,
|
||||||
|
False,
|
||||||
|
use_mla=use_mla)
|
||||||
|
expected = "FLASHINFER_MLA"
|
||||||
|
assert backend.get_name() == expected
|
||||||
elif name == "FLASHMLA":
|
elif name == "FLASHMLA":
|
||||||
if block_size != 64:
|
if block_size != 64:
|
||||||
# FlashMLA only supports block_size == 64
|
# FlashMLA only supports block_size == 64
|
||||||
@ -189,7 +213,7 @@ def test_env(
|
|||||||
else:
|
else:
|
||||||
backend = get_attn_backend(16,
|
backend = get_attn_backend(16,
|
||||||
torch.float16,
|
torch.float16,
|
||||||
torch.float16,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
False,
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
@ -204,7 +228,7 @@ def test_env(
|
|||||||
else:
|
else:
|
||||||
backend = get_attn_backend(16,
|
backend = get_attn_backend(16,
|
||||||
torch.float16,
|
torch.float16,
|
||||||
torch.float16,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
False,
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
@ -214,7 +238,7 @@ def test_env(
|
|||||||
# TRITON_MLA or other fallback
|
# TRITON_MLA or other fallback
|
||||||
backend = get_attn_backend(16,
|
backend = get_attn_backend(16,
|
||||||
torch.float16,
|
torch.float16,
|
||||||
torch.float16,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
False,
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
@ -224,7 +248,7 @@ def test_env(
|
|||||||
elif name == "FLASHINFER":
|
elif name == "FLASHINFER":
|
||||||
backend = get_attn_backend(16,
|
backend = get_attn_backend(16,
|
||||||
torch.float16,
|
torch.float16,
|
||||||
torch.float16,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
False,
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
@ -233,7 +257,7 @@ def test_env(
|
|||||||
else:
|
else:
|
||||||
backend = get_attn_backend(32,
|
backend = get_attn_backend(32,
|
||||||
torch.float16,
|
torch.float16,
|
||||||
torch.float16,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
False,
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
@ -243,7 +267,7 @@ def test_env(
|
|||||||
if use_v1:
|
if use_v1:
|
||||||
backend = get_attn_backend(16,
|
backend = get_attn_backend(16,
|
||||||
torch.float16,
|
torch.float16,
|
||||||
torch.float16,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
False,
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
@ -269,15 +293,13 @@ def test_fp32_fallback(
|
|||||||
|
|
||||||
with patch("vllm.attention.selector.current_platform",
|
with patch("vllm.attention.selector.current_platform",
|
||||||
CpuPlatform()):
|
CpuPlatform()):
|
||||||
backend = get_attn_backend(16, torch.float32, torch.float32,
|
backend = get_attn_backend(16, torch.float32, None, 16, False)
|
||||||
16, False)
|
|
||||||
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
||||||
|
|
||||||
elif device == "cuda":
|
elif device == "cuda":
|
||||||
with patch("vllm.attention.selector.current_platform",
|
with patch("vllm.attention.selector.current_platform",
|
||||||
CudaPlatform()):
|
CudaPlatform()):
|
||||||
backend = get_attn_backend(16, torch.float32, torch.float32,
|
backend = get_attn_backend(16, torch.float32, None, 16, False)
|
||||||
16, False)
|
|
||||||
assert (backend.get_name() == "FLEX_ATTENTION"
|
assert (backend.get_name() == "FLEX_ATTENTION"
|
||||||
if use_v1 else "XFORMERS")
|
if use_v1 else "XFORMERS")
|
||||||
|
|
||||||
@ -331,7 +353,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
|||||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||||
|
|
||||||
# Attention-free models should bypass env and use PlaceholderAttention
|
# Attention-free models should bypass env and use PlaceholderAttention
|
||||||
backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
|
backend = get_attn_backend(16, torch.float16, None, 16, True)
|
||||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from itertools import accumulate, product
|
from itertools import product
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -111,151 +111,6 @@ def test_rotary_embedding(
|
|||||||
"expected returned key to be None"
|
"expected returned key to be None"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
|
||||||
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
|
|
||||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
|
||||||
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
|
||||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
|
||||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
|
||||||
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
|
||||||
@pytest.mark.parametrize("seed", SEEDS)
|
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
|
||||||
@pytest.mark.parametrize("use_key", USE_KEY)
|
|
||||||
@torch.inference_mode()
|
|
||||||
def test_batched_rotary_embedding(
|
|
||||||
is_neox_style: bool,
|
|
||||||
tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
|
|
||||||
batch_size: int,
|
|
||||||
seq_len: int,
|
|
||||||
num_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
rotary_dim: Optional[int],
|
|
||||||
dtype: torch.dtype,
|
|
||||||
seed: int,
|
|
||||||
device: str,
|
|
||||||
use_key: bool,
|
|
||||||
max_position: int = 8192,
|
|
||||||
base: float = 10000,
|
|
||||||
) -> None:
|
|
||||||
current_platform.seed_everything(seed)
|
|
||||||
torch.set_default_device(device)
|
|
||||||
if rotary_dim is None:
|
|
||||||
rotary_dim = head_size
|
|
||||||
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
|
|
||||||
"rope_type": "linear",
|
|
||||||
"factor": (1, )
|
|
||||||
})
|
|
||||||
rope = rope.to(dtype=dtype, device=torch.get_default_device())
|
|
||||||
|
|
||||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
|
||||||
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
|
|
||||||
query = torch.randn(query_shape, dtype=dtype)
|
|
||||||
key = torch.randn_like(query) if use_key else None
|
|
||||||
|
|
||||||
# slice tensor if required, noop otherwise
|
|
||||||
query = query[..., :head_size]
|
|
||||||
key = key[..., :head_size] if use_key else None
|
|
||||||
|
|
||||||
# NOTE(woosuk): The reference implementation should be executed first
|
|
||||||
# because the custom kernel is in-place.
|
|
||||||
ref_query, ref_key = rope.forward_native(positions, query, key)
|
|
||||||
out_query, out_key = rope.forward(positions,
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
offsets=torch.zeros(batch_size * seq_len,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=device))
|
|
||||||
# Compare the results.
|
|
||||||
torch.testing.assert_close(out_query,
|
|
||||||
ref_query,
|
|
||||||
atol=get_default_atol(out_query),
|
|
||||||
rtol=get_default_rtol(out_query))
|
|
||||||
if use_key:
|
|
||||||
torch.testing.assert_close(out_key,
|
|
||||||
ref_key,
|
|
||||||
atol=get_default_atol(out_key),
|
|
||||||
rtol=get_default_rtol(out_key))
|
|
||||||
else:
|
|
||||||
assert ref_key is None and out_key is None, \
|
|
||||||
"expected returned key to be None"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
|
||||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
|
||||||
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
|
||||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
|
||||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
|
||||||
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
|
||||||
@pytest.mark.parametrize("seed", SEEDS)
|
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
|
||||||
@pytest.mark.parametrize("use_key", USE_KEY)
|
|
||||||
@torch.inference_mode()
|
|
||||||
def test_batched_rotary_embedding_multi_lora(
|
|
||||||
is_neox_style: bool,
|
|
||||||
batch_size: int,
|
|
||||||
seq_len: int,
|
|
||||||
num_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
rotary_dim: Optional[int],
|
|
||||||
dtype: torch.dtype,
|
|
||||||
seed: int,
|
|
||||||
device: str,
|
|
||||||
use_key: bool,
|
|
||||||
max_position: int = 8192,
|
|
||||||
base: float = 10000,
|
|
||||||
) -> None:
|
|
||||||
current_platform.seed_everything(seed)
|
|
||||||
torch.set_default_device(device)
|
|
||||||
if rotary_dim is None:
|
|
||||||
rotary_dim = head_size
|
|
||||||
scaling_factors: list[int] = [1, 2, 4]
|
|
||||||
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
|
|
||||||
"rope_type": "linear",
|
|
||||||
"factor": tuple(scaling_factors)
|
|
||||||
})
|
|
||||||
rope = rope.to(dtype=dtype, device=torch.get_default_device())
|
|
||||||
|
|
||||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
|
||||||
query = torch.randn(batch_size,
|
|
||||||
seq_len,
|
|
||||||
num_heads * head_size,
|
|
||||||
dtype=dtype)
|
|
||||||
key = torch.randn_like(query) if use_key else None
|
|
||||||
|
|
||||||
offset_map = torch.tensor(
|
|
||||||
list(
|
|
||||||
accumulate([0] + [
|
|
||||||
max_position * scaling_factor * 2
|
|
||||||
for scaling_factor in scaling_factors[:-1]
|
|
||||||
])))
|
|
||||||
query_types = torch.randint(0,
|
|
||||||
len(scaling_factors), (batch_size, seq_len),
|
|
||||||
device=device)
|
|
||||||
query_offsets = offset_map[query_types]
|
|
||||||
|
|
||||||
# NOTE(woosuk): The reference implementation should be executed first
|
|
||||||
# because the custom kernel is in-place.
|
|
||||||
ref_query, ref_key = rope.forward_native(positions, query, key,
|
|
||||||
query_offsets)
|
|
||||||
out_query, out_key = rope.forward(positions, query, key,
|
|
||||||
query_offsets.flatten())
|
|
||||||
# Compare the results.
|
|
||||||
torch.testing.assert_close(out_query,
|
|
||||||
ref_query,
|
|
||||||
atol=get_default_atol(out_query),
|
|
||||||
rtol=get_default_rtol(out_query))
|
|
||||||
if use_key:
|
|
||||||
torch.testing.assert_close(out_key,
|
|
||||||
ref_key,
|
|
||||||
atol=get_default_atol(out_key),
|
|
||||||
rtol=get_default_rtol(out_key))
|
|
||||||
else:
|
|
||||||
assert ref_key is None and out_key is None, \
|
|
||||||
"expected returned key to be None"
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_rope_module_cache():
|
def test_rope_module_cache():
|
||||||
MAX_POSITIONS = [123, 1234]
|
MAX_POSITIONS = [123, 1234]
|
||||||
|
|||||||
@ -16,20 +16,14 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
|||||||
def rotary_embedding_opcheck(rot,
|
def rotary_embedding_opcheck(rot,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: Optional[torch.Tensor] = None,
|
key: Optional[torch.Tensor] = None):
|
||||||
offsets: Optional[torch.Tensor] = None):
|
|
||||||
cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)
|
cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||||
|
|
||||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
# ops.rotary_embedding() is a in-place operation
|
||||||
# are in-place operations that update the query and key tensors.
|
# that updates the query and key tensors.
|
||||||
if offsets is not None:
|
opcheck(torch.ops._C.rotary_embedding,
|
||||||
opcheck(torch.ops._C.batched_rotary_embedding,
|
(positions, query, key, rot.head_size, cos_sin_cache,
|
||||||
(positions, query, key, rot.head_size, cos_sin_cache,
|
rot.is_neox_style))
|
||||||
rot.is_neox_style, rot.rotary_dim, offsets))
|
|
||||||
else:
|
|
||||||
opcheck(torch.ops._C.rotary_embedding,
|
|
||||||
(positions, query, key, rot.head_size, cos_sin_cache,
|
|
||||||
rot.is_neox_style))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", ["cuda"])
|
@pytest.mark.parametrize("device", ["cuda"])
|
||||||
@ -65,10 +59,6 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
|
|||||||
key = key[..., :head_size] if use_key else None
|
key = key[..., :head_size] if use_key else None
|
||||||
|
|
||||||
rotary_embedding_opcheck(rot, positions, query, key)
|
rotary_embedding_opcheck(rot, positions, query, key)
|
||||||
offsets = torch.zeros(batch_size * seq_len,
|
|
||||||
device=device,
|
|
||||||
dtype=torch.long)
|
|
||||||
rotary_embedding_opcheck(rot, positions, query, key, offsets)
|
|
||||||
|
|
||||||
# if we have a contiguous head stride, test the alternate
|
# if we have a contiguous head stride, test the alternate
|
||||||
# [..., num_heads * head_dim] shape/layout
|
# [..., num_heads * head_dim] shape/layout
|
||||||
|
|||||||
@ -771,11 +771,11 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
|
|||||||
w13_ref = dequant_mxfp4_batches(
|
w13_ref = dequant_mxfp4_batches(
|
||||||
w13_q.view(torch.uint8),
|
w13_q.view(torch.uint8),
|
||||||
w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
|
w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
|
||||||
num_experts, 2 * intermediate_size, hidden_size)
|
num_experts, 2 * intermediate_size, hidden_size).to(device)
|
||||||
w2_ref = dequant_mxfp4_batches(
|
w2_ref = dequant_mxfp4_batches(
|
||||||
w2_q.view(torch.uint8),
|
w2_q.view(torch.uint8),
|
||||||
w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
|
w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
|
||||||
num_experts, hidden_size, intermediate_size)
|
num_experts, hidden_size, intermediate_size).to(device)
|
||||||
|
|
||||||
# Quantize activations for SM100 path and dequantize for reference
|
# Quantize activations for SM100 path and dequantize for reference
|
||||||
hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
|
hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
|
||||||
|
|||||||
@ -5,28 +5,52 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||||
silu_mul_fp8_quant_deep_gemm)
|
silu_mul_fp8_quant_deep_gemm_cuda)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import cdiv
|
||||||
|
|
||||||
|
fp8_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
# (E, T, H, group_size, seed)
|
|
||||||
CASES = [
|
CASES = [
|
||||||
(1, 1, 128, 64, 0),
|
(1, 1, 128, fp8_dtype),
|
||||||
(1, 4, 128, 128, 0),
|
(1, 4, 128, fp8_dtype),
|
||||||
(2, 4, 256, 128, 0),
|
(2, 4, 256, fp8_dtype),
|
||||||
(32, 64, 256, 128, 0),
|
(32, 64, 256, fp8_dtype),
|
||||||
(17, 31, 768, 128, 0),
|
(17, 31, 768, fp8_dtype),
|
||||||
|
(1, 1, 128 * 1, fp8_dtype),
|
||||||
|
(1, 1, 128 * 2, fp8_dtype),
|
||||||
|
(1, 1, 128 * 3, fp8_dtype),
|
||||||
|
(1, 1, 128 * 4, fp8_dtype),
|
||||||
|
(8, 16, 128 * 1, fp8_dtype),
|
||||||
|
(8, 16, 128 * 2, fp8_dtype),
|
||||||
|
(8, 16, 128 * 3, fp8_dtype),
|
||||||
|
(8, 16, 128 * 4, fp8_dtype),
|
||||||
|
(8, 64, 7168, fp8_dtype),
|
||||||
|
(8, 128, 7168, fp8_dtype),
|
||||||
|
(8, 256, 7168, fp8_dtype),
|
||||||
|
(8, 512, 7168, fp8_dtype),
|
||||||
|
(8, 1024, 7168, fp8_dtype),
|
||||||
|
(256, 8, 7168, fp8_dtype),
|
||||||
|
(256, 16, 7168, fp8_dtype),
|
||||||
|
(256, 32, 7168, fp8_dtype),
|
||||||
|
(256, 64, 7168, fp8_dtype),
|
||||||
|
|
||||||
|
# Only add a few fnuz tests to help with long CI times.
|
||||||
|
(8, 512, 7168, torch.float8_e4m3fnuz),
|
||||||
|
(8, 1024, 7168, torch.float8_e4m3fnuz),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("E,T,H,group_size,seed", CASES)
|
@pytest.mark.parametrize("E,T,H,fp8_type", CASES)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
|
def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
|
||||||
current_platform.seed_everything(seed)
|
group_size = 128
|
||||||
|
current_platform.seed_everything(42)
|
||||||
|
|
||||||
# Input tensor of shape (E, T, 2*H)
|
# Input tensor of shape (E, T, 2*H)
|
||||||
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
|
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
|
||||||
tokens_per_expert = torch.randint(
|
tokens_per_expert = torch.randint(
|
||||||
low=0,
|
low=T // 2,
|
||||||
high=T,
|
high=T,
|
||||||
size=(E, ),
|
size=(E, ),
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
@ -34,45 +58,59 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Run the Triton kernel
|
# Run the Triton kernel
|
||||||
y_q, y_s = silu_mul_fp8_quant_deep_gemm(y,
|
y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(y,
|
||||||
tokens_per_expert,
|
tokens_per_expert,
|
||||||
group_size=group_size,
|
group_size=group_size)
|
||||||
eps=1e-10)
|
|
||||||
|
|
||||||
# Reference implementation
|
torch.cuda.synchronize()
|
||||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
fp8_info = torch.finfo(fp8_dtype)
|
||||||
fp8_max = fp8_info.max
|
fp8_max = fp8_info.max
|
||||||
fp8_min = fp8_info.min
|
fp8_min = fp8_info.min
|
||||||
eps = 1e-10
|
eps = 1e-10
|
||||||
|
|
||||||
# Compute silu activation and elementwise multiplication
|
y1 = y[..., :H].float()
|
||||||
y1 = y[..., :H]
|
|
||||||
y2 = y[..., H:]
|
y2 = y[..., H:]
|
||||||
silu_x = y1 * torch.sigmoid(y1)
|
silu_x = y1 * torch.sigmoid(y1)
|
||||||
merged = silu_x * y2
|
merged = silu_x * y2
|
||||||
|
|
||||||
# Compute reference scales and quantized output, skipping padded tokens
|
|
||||||
for e in range(E):
|
for e in range(E):
|
||||||
nt = tokens_per_expert[e].item()
|
nt = tokens_per_expert[e].item()
|
||||||
ref_s = torch.empty((T, H // group_size),
|
ref_s = torch.empty((T, cdiv(H, group_size)),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device="cuda")
|
device="cuda")
|
||||||
ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda")
|
ref_q = torch.empty((T, H), dtype=fp8_dtype, device="cuda")
|
||||||
|
|
||||||
for t in range(nt):
|
for t in range(nt):
|
||||||
data = merged[e, t]
|
data = merged[e, t].float()
|
||||||
data_grp = data.view(H // group_size, group_size)
|
ref_q_row = torch.empty_like(data)
|
||||||
amax = data_grp.abs().amax(dim=1).clamp(min=eps)
|
|
||||||
scale = amax / fp8_max
|
|
||||||
|
|
||||||
scaled = data / scale.repeat_interleave(group_size)
|
# process full groups
|
||||||
clamped = scaled.clamp(fp8_min, fp8_max)
|
n_full_groups = H // group_size
|
||||||
q = clamped.to(torch.float8_e4m3fn)
|
if n_full_groups > 0:
|
||||||
|
data_grp = data[:n_full_groups * group_size].view(
|
||||||
|
n_full_groups, group_size)
|
||||||
|
amax = data_grp.abs().amax(dim=1).clamp(min=eps)
|
||||||
|
scale = amax / fp8_max
|
||||||
|
scaled = data[:n_full_groups *
|
||||||
|
group_size] / scale.repeat_interleave(group_size)
|
||||||
|
ref_q_row[:n_full_groups * group_size] = scaled.clamp(
|
||||||
|
fp8_min, fp8_max).to(fp8_dtype)
|
||||||
|
ref_s[t, :n_full_groups] = scale
|
||||||
|
|
||||||
ref_s[t] = scale
|
# process remainder group
|
||||||
ref_q[t] = q
|
rem = H % group_size
|
||||||
|
if rem > 0:
|
||||||
|
data_rem = data[-rem:]
|
||||||
|
amax = data_rem.abs().amax().clamp(min=eps)
|
||||||
|
scale = amax / fp8_max
|
||||||
|
scaled = data_rem / scale
|
||||||
|
ref_q_row[-rem:] = scaled.clamp(fp8_min, fp8_max).to(fp8_dtype)
|
||||||
|
ref_s[t, -1] = scale
|
||||||
|
|
||||||
y_se = y_s[e]
|
ref_q[t] = ref_q_row
|
||||||
y_qe = y_q[e]
|
|
||||||
|
y_se = y_s[e].float()
|
||||||
|
y_qe = y_q[e].float()
|
||||||
|
|
||||||
torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2)
|
torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2)
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
|
|||||||
@ -301,6 +301,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
"OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
|
"OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
|
||||||
"Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"),
|
"Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"),
|
||||||
|
"Olmo3ForCausalLM": _HfExamplesInfo("shanearora/2025-sep-a-base-model"),
|
||||||
"OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
|
"OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
|
||||||
"OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m",
|
"OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m",
|
||||||
{"1b": "facebook/opt-iml-max-1.3b"}),
|
{"1b": "facebook/opt-iml-max-1.3b"}),
|
||||||
|
|||||||
@ -178,6 +178,7 @@ class MockAttentionLayer:
|
|||||||
self._k_scale = torch.tensor(1.0, device=device)
|
self._k_scale = torch.tensor(1.0, device=device)
|
||||||
self._v_scale = torch.tensor(1.0, device=device)
|
self._v_scale = torch.tensor(1.0, device=device)
|
||||||
# Add float versions for flashinfer
|
# Add float versions for flashinfer
|
||||||
|
self._q_scale_float = 1.0
|
||||||
self._k_scale_float = 1.0
|
self._k_scale_float = 1.0
|
||||||
self._v_scale_float = 1.0
|
self._v_scale_float = 1.0
|
||||||
|
|
||||||
|
|||||||
@ -141,6 +141,8 @@ def get_attention_backend(backend_name: _Backend):
|
|||||||
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
|
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
|
||||||
_Backend.FLASH_ATTN_MLA:
|
_Backend.FLASH_ATTN_MLA:
|
||||||
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
|
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
|
||||||
|
_Backend.FLASHINFER_MLA:
|
||||||
|
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",
|
||||||
_Backend.TRITON_MLA_VLLM_V1:
|
_Backend.TRITON_MLA_VLLM_V1:
|
||||||
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
|
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -117,9 +117,9 @@ def test_ngram_correctness(
|
|||||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||||
|
|
||||||
# Heuristic: expect at least 68% of the prompts to match exactly
|
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||||
assert matches >= int(0.68 * len(ref_outputs))
|
assert matches >= int(0.66 * len(ref_outputs))
|
||||||
del spec_llm
|
del spec_llm
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|||||||
@ -257,9 +257,13 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
|||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
scheduler_output,
|
scheduler_output,
|
||||||
|
non_block=False,
|
||||||
) -> Future[ModelRunnerOutput]:
|
) -> Future[ModelRunnerOutput]:
|
||||||
"""Make execute_model non-blocking."""
|
"""Make execute_model non-blocking."""
|
||||||
|
|
||||||
|
# DummyExecutor used only for testing async case.
|
||||||
|
assert non_block
|
||||||
|
|
||||||
def _execute():
|
def _execute():
|
||||||
output = self.collective_rpc("execute_model",
|
output = self.collective_rpc("execute_model",
|
||||||
args=(scheduler_output, ))
|
args=(scheduler_output, ))
|
||||||
|
|||||||
@ -33,10 +33,12 @@ def test_ragged_paged_attention():
|
|||||||
)
|
)
|
||||||
|
|
||||||
class FakeAttentionLayer:
|
class FakeAttentionLayer:
|
||||||
|
_q_scale_float: float
|
||||||
_k_scale_float: float
|
_k_scale_float: float
|
||||||
_v_scale_float: float
|
_v_scale_float: float
|
||||||
|
|
||||||
layer = FakeAttentionLayer()
|
layer = FakeAttentionLayer()
|
||||||
|
layer._q_scale_float = 1.0
|
||||||
layer._k_scale_float = 1.0
|
layer._k_scale_float = 1.0
|
||||||
layer._v_scale_float = 1.0
|
layer._v_scale_float = 1.0
|
||||||
|
|
||||||
|
|||||||
@ -257,16 +257,6 @@ def rotary_embedding(
|
|||||||
cos_sin_cache, is_neox)
|
cos_sin_cache, is_neox)
|
||||||
|
|
||||||
|
|
||||||
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
|
|
||||||
key: Optional[torch.Tensor], head_size: int,
|
|
||||||
cos_sin_cache: torch.Tensor, is_neox: bool,
|
|
||||||
rot_dim: int,
|
|
||||||
cos_sin_cache_offsets: torch.Tensor) -> None:
|
|
||||||
torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
|
|
||||||
cos_sin_cache, is_neox, rot_dim,
|
|
||||||
cos_sin_cache_offsets)
|
|
||||||
|
|
||||||
|
|
||||||
# layer norm ops
|
# layer norm ops
|
||||||
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
|
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
|
||||||
epsilon: float) -> None:
|
epsilon: float) -> None:
|
||||||
|
|||||||
@ -148,17 +148,6 @@ class ipex_ops:
|
|||||||
head_size, cos_sin_cache,
|
head_size, cos_sin_cache,
|
||||||
is_neox, rot_dim)
|
is_neox, rot_dim)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
|
|
||||||
key: torch.Tensor, head_size: int,
|
|
||||||
cos_sin_cache: torch.Tensor, is_neox: bool,
|
|
||||||
rot_dim: int,
|
|
||||||
cos_sin_cache_offsets: torch.Tensor) -> None:
|
|
||||||
ipex.llm.functional.rotary_embedding_batched(positions, query, key,
|
|
||||||
head_size, cos_sin_cache,
|
|
||||||
is_neox, rot_dim,
|
|
||||||
cos_sin_cache_offsets)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def rms_norm(input: torch.Tensor, weight: torch.Tensor,
|
def rms_norm(input: torch.Tensor, weight: torch.Tensor,
|
||||||
epsilon: float) -> torch.Tensor:
|
epsilon: float) -> torch.Tensor:
|
||||||
|
|||||||
@ -240,6 +240,7 @@ class AttentionLayer(Protocol):
|
|||||||
_q_scale: torch.Tensor
|
_q_scale: torch.Tensor
|
||||||
_k_scale: torch.Tensor
|
_k_scale: torch.Tensor
|
||||||
_v_scale: torch.Tensor
|
_v_scale: torch.Tensor
|
||||||
|
_q_scale_float: float
|
||||||
_k_scale_float: float
|
_k_scale_float: float
|
||||||
_v_scale_float: float
|
_v_scale_float: float
|
||||||
_prob_scale: torch.Tensor
|
_prob_scale: torch.Tensor
|
||||||
|
|||||||
@ -68,6 +68,7 @@ class RequestFuncInput:
|
|||||||
model: str
|
model: str
|
||||||
model_name: Optional[str] = None
|
model_name: Optional[str] = None
|
||||||
logprobs: Optional[int] = None
|
logprobs: Optional[int] = None
|
||||||
|
extra_headers: Optional[dict] = None
|
||||||
extra_body: Optional[dict] = None
|
extra_body: Optional[dict] = None
|
||||||
multi_modal_content: Optional[Union[dict, list[dict]]] = None
|
multi_modal_content: Optional[Union[dict, list[dict]]] = None
|
||||||
ignore_eos: bool = False
|
ignore_eos: bool = False
|
||||||
@ -129,6 +130,8 @@ async def async_request_openai_completions(
|
|||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||||
}
|
}
|
||||||
|
if request_func_input.extra_headers:
|
||||||
|
headers |= request_func_input.extra_headers
|
||||||
if request_func_input.request_id:
|
if request_func_input.request_id:
|
||||||
headers["x-request-id"] = request_func_input.request_id
|
headers["x-request-id"] = request_func_input.request_id
|
||||||
|
|
||||||
@ -258,6 +261,8 @@ async def async_request_openai_chat_completions(
|
|||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||||
}
|
}
|
||||||
|
if request_func_input.extra_headers:
|
||||||
|
headers |= request_func_input.extra_headers
|
||||||
if request_func_input.request_id:
|
if request_func_input.request_id:
|
||||||
headers["x-request-id"] = request_func_input.request_id
|
headers["x-request-id"] = request_func_input.request_id
|
||||||
|
|
||||||
@ -364,6 +369,8 @@ async def async_request_openai_audio(
|
|||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||||
}
|
}
|
||||||
|
if request_func_input.extra_headers:
|
||||||
|
headers |= request_func_input.extra_headers
|
||||||
if request_func_input.request_id:
|
if request_func_input.request_id:
|
||||||
headers["x-request-id"] = request_func_input.request_id
|
headers["x-request-id"] = request_func_input.request_id
|
||||||
|
|
||||||
|
|||||||
@ -389,6 +389,7 @@ async def benchmark(
|
|||||||
goodput_config_dict: dict[str, float],
|
goodput_config_dict: dict[str, float],
|
||||||
max_concurrency: Optional[int],
|
max_concurrency: Optional[int],
|
||||||
lora_modules: Optional[Iterable[str]],
|
lora_modules: Optional[Iterable[str]],
|
||||||
|
extra_headers: Optional[dict],
|
||||||
extra_body: Optional[dict],
|
extra_body: Optional[dict],
|
||||||
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
|
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
|
||||||
ramp_up_start_rps: Optional[int] = None,
|
ramp_up_start_rps: Optional[int] = None,
|
||||||
@ -452,6 +453,7 @@ async def benchmark(
|
|||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
multi_modal_content=test_mm_content,
|
multi_modal_content=test_mm_content,
|
||||||
ignore_eos=ignore_eos,
|
ignore_eos=ignore_eos,
|
||||||
|
extra_headers=extra_headers,
|
||||||
extra_body=extra_body,
|
extra_body=extra_body,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -484,6 +486,7 @@ async def benchmark(
|
|||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
multi_modal_content=test_mm_content,
|
multi_modal_content=test_mm_content,
|
||||||
ignore_eos=ignore_eos,
|
ignore_eos=ignore_eos,
|
||||||
|
extra_headers=extra_headers,
|
||||||
extra_body=extra_body)
|
extra_body=extra_body)
|
||||||
profile_output = await request_func(
|
profile_output = await request_func(
|
||||||
request_func_input=profile_input, session=session)
|
request_func_input=profile_input, session=session)
|
||||||
@ -568,6 +571,7 @@ async def benchmark(
|
|||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
multi_modal_content=mm_content,
|
multi_modal_content=mm_content,
|
||||||
ignore_eos=ignore_eos,
|
ignore_eos=ignore_eos,
|
||||||
|
extra_headers=extra_headers,
|
||||||
extra_body=extra_body,
|
extra_body=extra_body,
|
||||||
request_id=request_id,)
|
request_id=request_id,)
|
||||||
tasks.append(
|
tasks.append(
|
||||||
@ -815,6 +819,15 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
|||||||
default="/v1/completions",
|
default="/v1/completions",
|
||||||
help="API endpoint.",
|
help="API endpoint.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--header",
|
||||||
|
metavar="KEY=VALUE",
|
||||||
|
nargs="*",
|
||||||
|
help="Key-value pairs (e.g, --header x-additional-info=0.3.3) "
|
||||||
|
"for headers to be passed with each request. These headers override " \
|
||||||
|
"per backend constants and values set via environment variable, and " \
|
||||||
|
"will be overriden by other arguments (such as request ids)."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-concurrency",
|
"--max-concurrency",
|
||||||
type=int,
|
type=int,
|
||||||
@ -1104,6 +1117,19 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
|||||||
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
||||||
base_url = f"http://{args.host}:{args.port}"
|
base_url = f"http://{args.host}:{args.port}"
|
||||||
|
|
||||||
|
# Headers
|
||||||
|
headers = None
|
||||||
|
if args.header:
|
||||||
|
headers = {}
|
||||||
|
for item in args.header:
|
||||||
|
if "=" in item:
|
||||||
|
kvstring = item.split("=", 1)
|
||||||
|
headers[kvstring[0].strip()] = kvstring[1].strip()
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid header format. Please use KEY=VALUE format."
|
||||||
|
)
|
||||||
|
|
||||||
tokenizer = get_tokenizer(tokenizer_id,
|
tokenizer = get_tokenizer(tokenizer_id,
|
||||||
tokenizer_mode=tokenizer_mode,
|
tokenizer_mode=tokenizer_mode,
|
||||||
trust_remote_code=args.trust_remote_code)
|
trust_remote_code=args.trust_remote_code)
|
||||||
@ -1161,6 +1187,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
|||||||
goodput_config_dict=goodput_config_dict,
|
goodput_config_dict=goodput_config_dict,
|
||||||
max_concurrency=args.max_concurrency,
|
max_concurrency=args.max_concurrency,
|
||||||
lora_modules=args.lora_modules,
|
lora_modules=args.lora_modules,
|
||||||
|
extra_headers=headers,
|
||||||
extra_body=sampling_params,
|
extra_body=sampling_params,
|
||||||
ramp_up_strategy=args.ramp_up_strategy,
|
ramp_up_strategy=args.ramp_up_strategy,
|
||||||
ramp_up_start_rps=args.ramp_up_start_rps,
|
ramp_up_start_rps=args.ramp_up_start_rps,
|
||||||
@ -1184,7 +1211,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
|||||||
if args.metadata:
|
if args.metadata:
|
||||||
for item in args.metadata:
|
for item in args.metadata:
|
||||||
if "=" in item:
|
if "=" in item:
|
||||||
kvstring = item.split("=")
|
kvstring = item.split("=", 1)
|
||||||
result_json[kvstring[0].strip()] = kvstring[1].strip()
|
result_json[kvstring[0].strip()] = kvstring[1].strip()
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@ -182,7 +182,7 @@ class NaiveBlockAllocator(BlockAllocator):
|
|||||||
# Increment refcount for each block.
|
# Increment refcount for each block.
|
||||||
assert block.block_id is not None
|
assert block.block_id is not None
|
||||||
refcount = self._refcounter.incr(block.block_id)
|
refcount = self._refcounter.incr(block.block_id)
|
||||||
assert refcount != 1, "can't fork free'd block"
|
assert refcount != 1, "can't fork freed block"
|
||||||
|
|
||||||
forked_block = self._block_pool.init_block(
|
forked_block = self._block_pool.init_block(
|
||||||
prev_block=prev_block,
|
prev_block=prev_block,
|
||||||
|
|||||||
@ -58,7 +58,7 @@ class Evictor(ABC):
|
|||||||
|
|
||||||
class BlockMetaData:
|
class BlockMetaData:
|
||||||
"""Data structure for storing key data describe cached block, so that
|
"""Data structure for storing key data describe cached block, so that
|
||||||
evitor could use to make its decision which one to choose for eviction
|
evictor could use to make its decision which one to choose for eviction
|
||||||
|
|
||||||
Here we use physical block id as the dict key, as there maybe several
|
Here we use physical block id as the dict key, as there maybe several
|
||||||
blocks with the same content hash, but their physical id is unique.
|
blocks with the same content hash, but their physical id is unique.
|
||||||
|
|||||||
@ -337,11 +337,11 @@ class EplbState:
|
|||||||
Args:
|
Args:
|
||||||
model (MixtureOfExperts): The MoE model.
|
model (MixtureOfExperts): The MoE model.
|
||||||
is_dummy (bool): If `True`, this is a dummy step and the load
|
is_dummy (bool): If `True`, this is a dummy step and the load
|
||||||
metrics recorded in this forward pass will not count. Defaults
|
metrics recorded in this forward pass will not count. Defaults
|
||||||
to `False`.
|
to `False`.
|
||||||
is_profile (bool): If `True`, perform a dummy rearrangement
|
is_profile (bool): If `True`, perform a dummy rearrangement
|
||||||
with maximum communication cost. This is used in `profile_run`
|
with maximum communication cost. This is used in `profile_run`
|
||||||
to reserve enough memory for the communication buffer.
|
to reserve enough memory for the communication buffer.
|
||||||
log_stats (bool): If `True`, log the expert load metrics.
|
log_stats (bool): If `True`, log the expert load metrics.
|
||||||
|
|
||||||
# Stats
|
# Stats
|
||||||
|
|||||||
@ -102,14 +102,14 @@ def rebalance_experts_hierarchical(
|
|||||||
num_groups: int,
|
num_groups: int,
|
||||||
num_nodes: int,
|
num_nodes: int,
|
||||||
num_gpus: int,
|
num_gpus: int,
|
||||||
):
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Parameters:
|
Parameters:
|
||||||
weight: [num_moe_layers, num_logical_experts]
|
weight: [num_moe_layers, num_logical_experts]
|
||||||
num_physical_experts: number of physical experts after replication
|
num_physical_experts: number of physical experts after replication
|
||||||
num_groups: number of expert groups
|
num_groups: number of expert groups
|
||||||
num_nodes: number of server nodes, where the intra-node network
|
num_nodes: number of server nodes, where the intra-node network
|
||||||
(e.g, NVLink) is faster
|
(e.g, NVLink) is faster
|
||||||
num_gpus: number of GPUs, must be a multiple of `num_nodes`
|
num_gpus: number of GPUs, must be a multiple of `num_nodes`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|||||||
@ -149,7 +149,7 @@ class KVConnectorBase_V1(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def start_load_kv(self, forward_context: "ForwardContext",
|
def start_load_kv(self, forward_context: "ForwardContext",
|
||||||
**kwargs) -> None:
|
**kwargs: Any) -> None:
|
||||||
"""
|
"""
|
||||||
Start loading the KV cache from the connector to vLLM's paged
|
Start loading the KV cache from the connector to vLLM's paged
|
||||||
KV buffer. This is called from the forward context before the
|
KV buffer. This is called from the forward context before the
|
||||||
@ -182,7 +182,8 @@ class KVConnectorBase_V1(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
attn_metadata: "AttentionMetadata",
|
||||||
|
**kwargs: Any) -> None:
|
||||||
"""
|
"""
|
||||||
Start saving a layer of KV cache from vLLM's paged buffer
|
Start saving a layer of KV cache from vLLM's paged buffer
|
||||||
to the connector. This is called from within attention layer to
|
to the connector. This is called from within attention layer to
|
||||||
|
|||||||
@ -30,7 +30,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
|
|||||||
# Worker-side methods
|
# Worker-side methods
|
||||||
# ==============================
|
# ==============================
|
||||||
def start_load_kv(self, forward_context: "ForwardContext",
|
def start_load_kv(self, forward_context: "ForwardContext",
|
||||||
**kwargs) -> None:
|
**kwargs: Any) -> None:
|
||||||
"""
|
"""
|
||||||
Start loading the KV cache from the connector to vLLM's paged
|
Start loading the KV cache from the connector to vLLM's paged
|
||||||
KV buffer. This is called from the forward context before the
|
KV buffer. This is called from the forward context before the
|
||||||
@ -61,7 +61,8 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
|
|||||||
self._lmcache_engine.wait_for_layer_load(layer_name)
|
self._lmcache_engine.wait_for_layer_load(layer_name)
|
||||||
|
|
||||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
attn_metadata: "AttentionMetadata",
|
||||||
|
**kwargs: Any) -> None:
|
||||||
"""
|
"""
|
||||||
Start saving the a layer of KV cache from vLLM's paged buffer
|
Start saving the a layer of KV cache from vLLM's paged buffer
|
||||||
to the connector. This is called from within attention layer to
|
to the connector. This is called from within attention layer to
|
||||||
|
|||||||
@ -708,8 +708,6 @@ class NixlConnectorWorker:
|
|||||||
caches_data = []
|
caches_data = []
|
||||||
# With hybrid allocator, layers can share a kv cache tensor
|
# With hybrid allocator, layers can share a kv cache tensor
|
||||||
seen_base_addresses = []
|
seen_base_addresses = []
|
||||||
xfer_buffers = (self.host_xfer_buffers
|
|
||||||
if self.use_host_buffer else kv_caches)
|
|
||||||
|
|
||||||
# Note(tms): I modified this from the original region setup code.
|
# Note(tms): I modified this from the original region setup code.
|
||||||
# K and V are now in different regions. Advantage is that we can
|
# K and V are now in different regions. Advantage is that we can
|
||||||
|
|||||||
@ -91,7 +91,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
|
|||||||
# ==============================
|
# ==============================
|
||||||
|
|
||||||
def start_load_kv(self, forward_context: "ForwardContext",
|
def start_load_kv(self, forward_context: "ForwardContext",
|
||||||
**kwargs) -> None:
|
**kwargs: Any) -> None:
|
||||||
"""Start loading the KV cache from the connector buffer to vLLM's
|
"""Start loading the KV cache from the connector buffer to vLLM's
|
||||||
paged KV buffer.
|
paged KV buffer.
|
||||||
|
|
||||||
@ -212,7 +212,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
|
|||||||
return
|
return
|
||||||
|
|
||||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
attn_metadata: "AttentionMetadata",
|
||||||
|
**kwargs: Any) -> None:
|
||||||
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
||||||
to the connector.
|
to the connector.
|
||||||
|
|
||||||
@ -278,7 +279,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
|
|||||||
|
|
||||||
def get_finished(
|
def get_finished(
|
||||||
self, finished_req_ids: set[str],
|
self, finished_req_ids: set[str],
|
||||||
**kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
**kwargs: Any) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||||
"""
|
"""
|
||||||
Notifies worker-side connector ids of requests that have
|
Notifies worker-side connector ids of requests that have
|
||||||
finished generating tokens.
|
finished generating tokens.
|
||||||
|
|||||||
@ -218,8 +218,9 @@ class TensorMemoryPool:
|
|||||||
|
|
||||||
return addr
|
return addr
|
||||||
|
|
||||||
def load_tensor(self, addr: int, dtype: torch.dtype,
|
def load_tensor(self, addr: int, dtype: torch.dtype, shape: tuple[int,
|
||||||
shape: tuple[int, ...], device) -> torch.Tensor:
|
...],
|
||||||
|
device: torch.device) -> torch.Tensor:
|
||||||
"""Loads a tensor from pinned host memory to the specified device.
|
"""Loads a tensor from pinned host memory to the specified device.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import safetensors
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
@ -90,7 +90,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
|
|||||||
logger.info("Shared storage path is %s", self._storage_path)
|
logger.info("Shared storage path is %s", self._storage_path)
|
||||||
|
|
||||||
def start_load_kv(self, forward_context: "ForwardContext",
|
def start_load_kv(self, forward_context: "ForwardContext",
|
||||||
**kwargs) -> None:
|
**kwargs: Any) -> None:
|
||||||
"""Start loading the KV cache from the connector buffer to vLLM's
|
"""Start loading the KV cache from the connector buffer to vLLM's
|
||||||
paged KV buffer.
|
paged KV buffer.
|
||||||
|
|
||||||
@ -191,7 +191,8 @@ class SharedStorageConnector(KVConnectorBase_V1):
|
|||||||
return
|
return
|
||||||
|
|
||||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
attn_metadata: "AttentionMetadata",
|
||||||
|
**kwargs: Any) -> None:
|
||||||
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
||||||
to the connector.
|
to the connector.
|
||||||
|
|
||||||
|
|||||||
@ -251,8 +251,8 @@ class PyNcclPipe(KVPipeBase):
|
|||||||
"""
|
"""
|
||||||
Receives a tensor and its metadata from the source rank. Blocking call.
|
Receives a tensor and its metadata from the source rank. Blocking call.
|
||||||
|
|
||||||
Args:
|
Returns:
|
||||||
tensor: The received tensor, or `None` if no tensor is received.
|
The received tensor, or `None` if no tensor is received.
|
||||||
"""
|
"""
|
||||||
if self.transport_thread is None:
|
if self.transport_thread is None:
|
||||||
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
||||||
|
|||||||
@ -1296,11 +1296,8 @@ class EngineArgs:
|
|||||||
# Async scheduling does not work with the uniprocess backend.
|
# Async scheduling does not work with the uniprocess backend.
|
||||||
if self.distributed_executor_backend is None:
|
if self.distributed_executor_backend is None:
|
||||||
self.distributed_executor_backend = "mp"
|
self.distributed_executor_backend = "mp"
|
||||||
logger.info("Using mp-based distributed executor backend "
|
logger.info("Defaulting to mp-based distributed executor "
|
||||||
"for async scheduling.")
|
"backend for async scheduling.")
|
||||||
if self.distributed_executor_backend == "uni":
|
|
||||||
raise ValueError("Async scheduling is not supported with "
|
|
||||||
"uni-process backend.")
|
|
||||||
if self.pipeline_parallel_size > 1:
|
if self.pipeline_parallel_size > 1:
|
||||||
raise ValueError("Async scheduling is not supported with "
|
raise ValueError("Async scheduling is not supported with "
|
||||||
"pipeline-parallel-size > 1.")
|
"pipeline-parallel-size > 1.")
|
||||||
|
|||||||
@ -379,7 +379,7 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
if local_interval_elapsed(stats.now, self.last_local_log,
|
if local_interval_elapsed(stats.now, self.last_local_log,
|
||||||
self.local_interval):
|
self.local_interval):
|
||||||
# Compute summary metrics for tracked stats (and log them
|
# Compute summary metrics for tracked stats (and log them
|
||||||
# to promethus if applicable).
|
# to prometheus if applicable).
|
||||||
prompt_throughput = get_throughput(self.num_prompt_tokens,
|
prompt_throughput = get_throughput(self.num_prompt_tokens,
|
||||||
now=stats.now,
|
now=stats.now,
|
||||||
last_log=self.last_local_log)
|
last_log=self.last_local_log)
|
||||||
@ -432,7 +432,7 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
|
|
||||||
|
|
||||||
class PrometheusStatLogger(StatLoggerBase):
|
class PrometheusStatLogger(StatLoggerBase):
|
||||||
"""PrometheusStatLogger is used LLMEngine to log to Promethus."""
|
"""PrometheusStatLogger is used LLMEngine to log to Prometheus."""
|
||||||
_metrics_cls = Metrics
|
_metrics_cls = Metrics
|
||||||
_gauge_cls = prometheus_client.Gauge
|
_gauge_cls = prometheus_client.Gauge
|
||||||
|
|
||||||
|
|||||||
@ -73,15 +73,10 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
|
|||||||
|
|
||||||
type: Required[Literal["audio_url"]]
|
type: Required[Literal["audio_url"]]
|
||||||
"""The type of the content part."""
|
"""The type of the content part."""
|
||||||
uuid: Optional[str]
|
|
||||||
"""
|
|
||||||
User-provided UUID of a media. User must guarantee that it is properly
|
|
||||||
generated and unique for different medias.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
|
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
|
||||||
image_embeds: Required[Union[str, dict[str, str]]]
|
image_embeds: Optional[Union[str, dict[str, str]]]
|
||||||
"""
|
"""
|
||||||
The image embeddings. It can be either:
|
The image embeddings. It can be either:
|
||||||
- A single base64 string.
|
- A single base64 string.
|
||||||
@ -108,11 +103,6 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
|
|||||||
|
|
||||||
type: Required[Literal["video_url"]]
|
type: Required[Literal["video_url"]]
|
||||||
"""The type of the content part."""
|
"""The type of the content part."""
|
||||||
uuid: Optional[str]
|
|
||||||
"""
|
|
||||||
User-provided UUID of a media. User must guarantee that it is properly
|
|
||||||
generated and unique for different medias.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class PILImage(BaseModel):
|
class PILImage(BaseModel):
|
||||||
@ -133,7 +123,7 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
image_pil: Required[PILImage]
|
image_pil: Optional[PILImage]
|
||||||
uuid: Optional[str]
|
uuid: Optional[str]
|
||||||
"""
|
"""
|
||||||
User-provided UUID of a media. User must guarantee that it is properly
|
User-provided UUID of a media. User must guarantee that it is properly
|
||||||
@ -151,7 +141,7 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
image_url: Required[str]
|
image_url: Optional[str]
|
||||||
uuid: Optional[str]
|
uuid: Optional[str]
|
||||||
"""
|
"""
|
||||||
User-provided UUID of a media. User must guarantee that it is properly
|
User-provided UUID of a media. User must guarantee that it is properly
|
||||||
@ -168,7 +158,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
audio_url: Required[str]
|
audio_url: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
|
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
|
||||||
@ -180,7 +170,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
video_url: Required[str]
|
video_url: Optional[str]
|
||||||
uuid: Optional[str]
|
uuid: Optional[str]
|
||||||
"""
|
"""
|
||||||
User-provided UUID of a media. User must guarantee that it is properly
|
User-provided UUID of a media. User must guarantee that it is properly
|
||||||
@ -597,7 +587,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
|||||||
self._model_config = model_config
|
self._model_config = model_config
|
||||||
self._tokenizer = tokenizer
|
self._tokenizer = tokenizer
|
||||||
|
|
||||||
self._items_by_modality = defaultdict[str, list[_T]](list)
|
self._items_by_modality = defaultdict[str, list[Optional[_T]]](list)
|
||||||
self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list)
|
self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -624,14 +614,17 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
|||||||
return self.mm_registry.create_processor(self.model_config)
|
return self.mm_registry.create_processor(self.model_config)
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self, modality: ModalityStr, item: _T, uuid: Optional[str] = None
|
self,
|
||||||
|
modality: ModalityStr,
|
||||||
|
item: Optional[_T],
|
||||||
|
uuid: Optional[str] = None,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Add a multi-modal item to the current prompt and returns the
|
Add a multi-modal item to the current prompt and returns the
|
||||||
placeholder string to use, if any.
|
placeholder string to use, if any.
|
||||||
|
|
||||||
An optional uuid can be added which serves as a unique identifier of the
|
An optional uuid can be added which serves as a unique identifier of the
|
||||||
media.
|
media.
|
||||||
"""
|
"""
|
||||||
input_modality = modality.replace("_embeds", "")
|
input_modality = modality.replace("_embeds", "")
|
||||||
num_items = len(self._items_by_modality[modality]) + 1
|
num_items = len(self._items_by_modality[modality]) + 1
|
||||||
@ -708,10 +701,15 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
|
|||||||
if not self._items_by_modality:
|
if not self._items_by_modality:
|
||||||
return None
|
return None
|
||||||
mm_inputs = {}
|
mm_inputs = {}
|
||||||
items_by_modality = {
|
items_by_modality = {}
|
||||||
modality: await asyncio.gather(*items)
|
for modality, items in self._items_by_modality.items():
|
||||||
for modality, items in self._items_by_modality.items()
|
coros = []
|
||||||
}
|
for item in items:
|
||||||
|
if item is not None:
|
||||||
|
coros.append(item)
|
||||||
|
else:
|
||||||
|
coros.append(asyncio.sleep(0))
|
||||||
|
items_by_modality[modality] = await asyncio.gather(*coros)
|
||||||
|
|
||||||
if "image" in items_by_modality and "image_embeds" in items_by_modality:
|
if "image" in items_by_modality and "image_embeds" in items_by_modality:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -760,35 +758,40 @@ class BaseMultiModalContentParser(ABC):
|
|||||||
return dict(self._placeholder_storage)
|
return dict(self._placeholder_storage)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
|
def parse_image(
|
||||||
|
self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def parse_image_embeds(
|
def parse_image_embeds(
|
||||||
self,
|
self,
|
||||||
image_embeds: Union[str, dict[str, str]],
|
image_embeds: Union[str, dict[str, str], None],
|
||||||
uuid: Optional[str] = None,
|
uuid: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def parse_image_pil(
|
def parse_image_pil(
|
||||||
self, image_pil: Image.Image, uuid: Optional[str] = None
|
self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
|
def parse_audio(
|
||||||
|
self, audio_url: Optional[str], uuid: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def parse_input_audio(
|
def parse_input_audio(
|
||||||
self, input_audio: InputAudio, uuid: Optional[str] = None
|
self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
|
def parse_video(
|
||||||
|
self, video_url: Optional[str], uuid: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@ -803,15 +806,17 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
|||||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
|
def parse_image(
|
||||||
image = self._connector.fetch_image(image_url)
|
self, image_url: Optional[str], uuid: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
image = self._connector.fetch_image(image_url) if image_url else None
|
||||||
|
|
||||||
placeholder = self._tracker.add("image", image, uuid)
|
placeholder = self._tracker.add("image", image, uuid)
|
||||||
self._add_placeholder("image", placeholder)
|
self._add_placeholder("image", placeholder)
|
||||||
|
|
||||||
def parse_image_embeds(
|
def parse_image_embeds(
|
||||||
self,
|
self,
|
||||||
image_embeds: Union[str, dict[str, str]],
|
image_embeds: Union[str, dict[str, str], None],
|
||||||
uuid: Optional[str] = None,
|
uuid: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(image_embeds, dict):
|
if isinstance(image_embeds, dict):
|
||||||
@ -825,31 +830,49 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
|||||||
embedding = self._connector.fetch_image_embedding(image_embeds)
|
embedding = self._connector.fetch_image_embedding(image_embeds)
|
||||||
placeholder = self._tracker.add("image_embeds", embedding, uuid)
|
placeholder = self._tracker.add("image_embeds", embedding, uuid)
|
||||||
|
|
||||||
|
if image_embeds is None:
|
||||||
|
placeholder = self._tracker.add("image_embeds", None, uuid)
|
||||||
|
|
||||||
self._add_placeholder("image", placeholder)
|
self._add_placeholder("image", placeholder)
|
||||||
|
|
||||||
def parse_image_pil(
|
def parse_image_pil(
|
||||||
self, image_pil: Image.Image, uuid: Optional[str] = None
|
self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
placeholder = self._tracker.add("image", image_pil, uuid)
|
placeholder = self._tracker.add("image", image_pil, uuid)
|
||||||
self._add_placeholder("image", placeholder)
|
self._add_placeholder("image", placeholder)
|
||||||
|
|
||||||
def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
|
def parse_audio(
|
||||||
audio = self._connector.fetch_audio(audio_url)
|
self, audio_url: Optional[str], uuid: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
audio = self._connector.fetch_audio(audio_url) if audio_url else None
|
||||||
|
|
||||||
placeholder = self._tracker.add("audio", audio, uuid)
|
placeholder = self._tracker.add("audio", audio, uuid)
|
||||||
self._add_placeholder("audio", placeholder)
|
self._add_placeholder("audio", placeholder)
|
||||||
|
|
||||||
def parse_input_audio(
|
def parse_input_audio(
|
||||||
self, input_audio: InputAudio, uuid: Optional[str] = None
|
self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
audio_data = input_audio.get("data", "")
|
if input_audio:
|
||||||
audio_format = input_audio.get("format", "")
|
audio_data = input_audio.get("data", "")
|
||||||
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
|
audio_format = input_audio.get("format", "")
|
||||||
|
if audio_data:
|
||||||
|
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
|
||||||
|
else:
|
||||||
|
# If a UUID is provided, audio data may be empty.
|
||||||
|
audio_url = None
|
||||||
|
else:
|
||||||
|
audio_url = None
|
||||||
|
|
||||||
return self.parse_audio(audio_url, uuid)
|
return self.parse_audio(audio_url, uuid)
|
||||||
|
|
||||||
def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
|
def parse_video(
|
||||||
video = self._connector.fetch_video(video_url=video_url)
|
self, video_url: Optional[str], uuid: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
video = (
|
||||||
|
self._connector.fetch_video(video_url=video_url)
|
||||||
|
if video_url
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
placeholder = self._tracker.add("video", video, uuid)
|
placeholder = self._tracker.add("video", video, uuid)
|
||||||
self._add_placeholder("video", placeholder)
|
self._add_placeholder("video", placeholder)
|
||||||
@ -865,18 +888,24 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
|||||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
|
def parse_image(
|
||||||
image_coro = self._connector.fetch_image_async(image_url)
|
self, image_url: Optional[str], uuid: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
image_coro = (
|
||||||
|
self._connector.fetch_image_async(image_url) if image_url else None
|
||||||
|
)
|
||||||
|
|
||||||
placeholder = self._tracker.add("image", image_coro, uuid)
|
placeholder = self._tracker.add("image", image_coro, uuid)
|
||||||
self._add_placeholder("image", placeholder)
|
self._add_placeholder("image", placeholder)
|
||||||
|
|
||||||
def parse_image_embeds(
|
def parse_image_embeds(
|
||||||
self,
|
self,
|
||||||
image_embeds: Union[str, dict[str, str]],
|
image_embeds: Union[str, dict[str, str], None],
|
||||||
uuid: Optional[str] = None,
|
uuid: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
|
future: asyncio.Future[Union[str, dict[str, str], None]] = (
|
||||||
|
asyncio.Future()
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(image_embeds, dict):
|
if isinstance(image_embeds, dict):
|
||||||
embeds = {
|
embeds = {
|
||||||
@ -889,35 +918,58 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
|||||||
embedding = self._connector.fetch_image_embedding(image_embeds)
|
embedding = self._connector.fetch_image_embedding(image_embeds)
|
||||||
future.set_result(embedding)
|
future.set_result(embedding)
|
||||||
|
|
||||||
|
if image_embeds is None:
|
||||||
|
future.set_result(None)
|
||||||
|
|
||||||
placeholder = self._tracker.add("image_embeds", future, uuid)
|
placeholder = self._tracker.add("image_embeds", future, uuid)
|
||||||
self._add_placeholder("image", placeholder)
|
self._add_placeholder("image", placeholder)
|
||||||
|
|
||||||
def parse_image_pil(
|
def parse_image_pil(
|
||||||
self, image_pil: Image.Image, uuid: Optional[str] = None
|
self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
future: asyncio.Future[Image.Image] = asyncio.Future()
|
future: asyncio.Future[Optional[Image.Image]] = asyncio.Future()
|
||||||
future.set_result(image_pil)
|
if image_pil:
|
||||||
|
future.set_result(image_pil)
|
||||||
|
else:
|
||||||
|
future.set_result(None)
|
||||||
|
|
||||||
placeholder = self._tracker.add("image", future, uuid)
|
placeholder = self._tracker.add("image", future, uuid)
|
||||||
self._add_placeholder("image", placeholder)
|
self._add_placeholder("image", placeholder)
|
||||||
|
|
||||||
def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
|
def parse_audio(
|
||||||
audio_coro = self._connector.fetch_audio_async(audio_url)
|
self, audio_url: Optional[str], uuid: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
audio_coro = (
|
||||||
|
self._connector.fetch_audio_async(audio_url) if audio_url else None
|
||||||
|
)
|
||||||
|
|
||||||
placeholder = self._tracker.add("audio", audio_coro, uuid)
|
placeholder = self._tracker.add("audio", audio_coro, uuid)
|
||||||
self._add_placeholder("audio", placeholder)
|
self._add_placeholder("audio", placeholder)
|
||||||
|
|
||||||
def parse_input_audio(
|
def parse_input_audio(
|
||||||
self, input_audio: InputAudio, uuid: Optional[str] = None
|
self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
audio_data = input_audio.get("data", "")
|
if input_audio:
|
||||||
audio_format = input_audio.get("format", "")
|
audio_data = input_audio.get("data", "")
|
||||||
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
|
audio_format = input_audio.get("format", "")
|
||||||
|
if audio_data:
|
||||||
|
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
|
||||||
|
else:
|
||||||
|
# If a UUID is provided, audio data may be empty.
|
||||||
|
audio_url = None
|
||||||
|
else:
|
||||||
|
audio_url = None
|
||||||
|
|
||||||
return self.parse_audio(audio_url, uuid)
|
return self.parse_audio(audio_url, uuid)
|
||||||
|
|
||||||
def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
|
def parse_video(
|
||||||
video = self._connector.fetch_video_async(video_url=video_url)
|
self, video_url: Optional[str], uuid: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
video = (
|
||||||
|
self._connector.fetch_video_async(video_url=video_url)
|
||||||
|
if video_url
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
placeholder = self._tracker.add("video", video, uuid)
|
placeholder = self._tracker.add("video", video, uuid)
|
||||||
self._add_placeholder("video", placeholder)
|
self._add_placeholder("video", placeholder)
|
||||||
@ -1130,8 +1182,9 @@ def _parse_chat_message_content_mm_part(
|
|||||||
part, dict
|
part, dict
|
||||||
) # This is needed to avoid mypy errors: part.get() from str
|
) # This is needed to avoid mypy errors: part.get() from str
|
||||||
part_type = part.get("type", None)
|
part_type = part.get("type", None)
|
||||||
|
uuid = part.get("uuid", None)
|
||||||
|
|
||||||
if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
|
if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None: # noqa: E501
|
||||||
content = MM_PARSER_MAP[part_type](part)
|
content = MM_PARSER_MAP[part_type](part)
|
||||||
|
|
||||||
# Special case for 'image_url.detail'
|
# Special case for 'image_url.detail'
|
||||||
@ -1146,25 +1199,54 @@ def _parse_chat_message_content_mm_part(
|
|||||||
|
|
||||||
# Handle missing 'type' but provided direct URL fields.
|
# Handle missing 'type' but provided direct URL fields.
|
||||||
# 'type' is required field by pydantic
|
# 'type' is required field by pydantic
|
||||||
if part_type is None:
|
if part_type is None or uuid is not None:
|
||||||
if part.get("image_url") is not None:
|
if "image_url" in part:
|
||||||
image_params = cast(
|
image_params = cast(
|
||||||
CustomChatCompletionContentSimpleImageParam, part
|
CustomChatCompletionContentSimpleImageParam, part
|
||||||
)
|
)
|
||||||
return "image_url", image_params.get("image_url", "")
|
image_url = image_params.get("image_url", None)
|
||||||
if part.get("audio_url") is not None:
|
if isinstance(image_url, dict):
|
||||||
|
# Can potentially happen if user provides a uuid
|
||||||
|
# with url as a dict of {"url": url}
|
||||||
|
image_url = image_url.get("url", None)
|
||||||
|
return "image_url", image_url
|
||||||
|
if "image_pil" in part:
|
||||||
|
# "image_pil" could be None if UUID is provided.
|
||||||
|
image_params = cast( # type: ignore
|
||||||
|
CustomChatCompletionContentPILImageParam, part
|
||||||
|
)
|
||||||
|
image_pil = image_params.get("image_pil", None)
|
||||||
|
return "image_pil", image_pil
|
||||||
|
if "image_embeds" in part:
|
||||||
|
# "image_embeds" could be None if UUID is provided.
|
||||||
|
image_params = cast( # type: ignore
|
||||||
|
ChatCompletionContentPartImageEmbedsParam, part
|
||||||
|
)
|
||||||
|
image_embeds = image_params.get("image_embeds", None)
|
||||||
|
return "image_embeds", image_embeds
|
||||||
|
if "audio_url" in part:
|
||||||
audio_params = cast(
|
audio_params = cast(
|
||||||
CustomChatCompletionContentSimpleAudioParam, part
|
CustomChatCompletionContentSimpleAudioParam, part
|
||||||
)
|
)
|
||||||
return "audio_url", audio_params.get("audio_url", "")
|
audio_url = audio_params.get("audio_url", None)
|
||||||
|
if isinstance(audio_url, dict):
|
||||||
|
# Can potentially happen if user provides a uuid
|
||||||
|
# with url as a dict of {"url": url}
|
||||||
|
audio_url = audio_url.get("url", None)
|
||||||
|
return "audio_url", audio_url
|
||||||
if part.get("input_audio") is not None:
|
if part.get("input_audio") is not None:
|
||||||
input_audio_params = cast(dict[str, str], part)
|
input_audio_params = cast(dict[str, str], part)
|
||||||
return "input_audio", input_audio_params
|
return "input_audio", input_audio_params
|
||||||
if part.get("video_url") is not None:
|
if "video_url" in part:
|
||||||
video_params = cast(
|
video_params = cast(
|
||||||
CustomChatCompletionContentSimpleVideoParam, part
|
CustomChatCompletionContentSimpleVideoParam, part
|
||||||
)
|
)
|
||||||
return "video_url", video_params.get("video_url", "")
|
video_url = video_params.get("video_url", None)
|
||||||
|
if isinstance(video_url, dict):
|
||||||
|
# Can potentially happen if user provides a uuid
|
||||||
|
# with url as a dict of {"url": url}
|
||||||
|
video_url = video_url.get("url", None)
|
||||||
|
return "video_url", video_url
|
||||||
# Raise an error if no 'type' or direct URL is found.
|
# Raise an error if no 'type' or direct URL is found.
|
||||||
raise ValueError("Missing 'type' field in multimodal part.")
|
raise ValueError("Missing 'type' field in multimodal part.")
|
||||||
|
|
||||||
@ -1173,15 +1255,9 @@ def _parse_chat_message_content_mm_part(
|
|||||||
return part_type, "unknown part_type content"
|
return part_type, "unknown part_type content"
|
||||||
|
|
||||||
|
|
||||||
VALID_MESSAGE_CONTENT_MM_PART_TYPES = (
|
PART_TYPES_TO_SKIP_NONE_CONTENT = (
|
||||||
"text",
|
"text",
|
||||||
"refusal",
|
"refusal",
|
||||||
"image_url",
|
|
||||||
"image_embeds",
|
|
||||||
"image_pil",
|
|
||||||
"audio_url",
|
|
||||||
"input_audio",
|
|
||||||
"video_url",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -1242,7 +1318,7 @@ def _parse_chat_message_content_part(
|
|||||||
part_type, content = _parse_chat_message_content_mm_part(part)
|
part_type, content = _parse_chat_message_content_mm_part(part)
|
||||||
# if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
|
# if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
|
||||||
# content is None, log a warning and skip
|
# content is None, log a warning and skip
|
||||||
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
|
if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Skipping multimodal part '%s' (type: '%s') "
|
"Skipping multimodal part '%s' (type: '%s') "
|
||||||
"with empty / unparsable content.",
|
"with empty / unparsable content.",
|
||||||
@ -1266,7 +1342,10 @@ def _parse_chat_message_content_part(
|
|||||||
|
|
||||||
modality = None
|
modality = None
|
||||||
if part_type == "image_pil":
|
if part_type == "image_pil":
|
||||||
image_content = cast(Image.Image, content)
|
if content is not None:
|
||||||
|
image_content = cast(Image.Image, content)
|
||||||
|
else:
|
||||||
|
image_content = None
|
||||||
mm_parser.parse_image_pil(image_content, uuid)
|
mm_parser.parse_image_pil(image_content, uuid)
|
||||||
modality = "image"
|
modality = "image"
|
||||||
elif part_type in ("image_url", "input_image"):
|
elif part_type in ("image_url", "input_image"):
|
||||||
@ -1274,7 +1353,10 @@ def _parse_chat_message_content_part(
|
|||||||
mm_parser.parse_image(str_content, uuid)
|
mm_parser.parse_image(str_content, uuid)
|
||||||
modality = "image"
|
modality = "image"
|
||||||
elif part_type == "image_embeds":
|
elif part_type == "image_embeds":
|
||||||
content = cast(Union[str, dict[str, str]], content)
|
if content is not None:
|
||||||
|
content = cast(Union[str, dict[str, str]], content)
|
||||||
|
else:
|
||||||
|
content = None
|
||||||
mm_parser.parse_image_embeds(content, uuid)
|
mm_parser.parse_image_embeds(content, uuid)
|
||||||
modality = "image"
|
modality = "image"
|
||||||
elif part_type == "audio_url":
|
elif part_type == "audio_url":
|
||||||
|
|||||||
@ -1491,6 +1491,11 @@ class LLM:
|
|||||||
|
|
||||||
for i, prompt in enumerate(it):
|
for i, prompt in enumerate(it):
|
||||||
|
|
||||||
|
if isinstance(prompt, dict):
|
||||||
|
self._validate_mm_data_and_uuids(
|
||||||
|
prompt.get("multi_modal_data"),
|
||||||
|
prompt.get("multi_modal_uuids"))
|
||||||
|
|
||||||
param = params[i] if isinstance(params, Sequence) else params
|
param = params[i] if isinstance(params, Sequence) else params
|
||||||
|
|
||||||
tokenization_kwargs: dict[str, Any] = {}
|
tokenization_kwargs: dict[str, Any] = {}
|
||||||
@ -1507,6 +1512,41 @@ class LLM:
|
|||||||
priority=priority[i] if priority else 0,
|
priority=priority[i] if priority else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _validate_mm_data_and_uuids(
|
||||||
|
self,
|
||||||
|
multi_modal_data: Optional[Any], # MultiModalDataDict
|
||||||
|
multi_modal_uuids: Optional[Any], # MultiModalUUIDDict
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Validate that if any multi-modal data is skipped (i.e. None),
|
||||||
|
then its corresponding UUID must be set.
|
||||||
|
"""
|
||||||
|
if multi_modal_data is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
for modality, data in multi_modal_data.items():
|
||||||
|
if isinstance(data, list):
|
||||||
|
for i, d in enumerate(data):
|
||||||
|
if d is None:
|
||||||
|
if multi_modal_uuids is None or modality not in multi_modal_uuids or multi_modal_uuids[ # noqa: E501
|
||||||
|
modality] is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Multi-modal data for {modality} is None "
|
||||||
|
f"but UUID is not provided")
|
||||||
|
else:
|
||||||
|
if len(
|
||||||
|
multi_modal_uuids[modality]
|
||||||
|
) <= i or multi_modal_uuids[modality][i] is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Multi-modal data for {modality} is None "
|
||||||
|
f"but UUID is not provided")
|
||||||
|
else:
|
||||||
|
if data is None and (multi_modal_uuids is None
|
||||||
|
or modality not in multi_modal_uuids
|
||||||
|
or multi_modal_uuids[modality] is None):
|
||||||
|
raise ValueError(f"Multi-modal data for {modality} is None"
|
||||||
|
f" but UUID is not provided")
|
||||||
|
|
||||||
def _add_request(
|
def _add_request(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
prompt: PromptType,
|
||||||
|
|||||||
@ -476,6 +476,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# - "FLASHINFER": use flashinfer
|
# - "FLASHINFER": use flashinfer
|
||||||
# - "FLASHMLA": use FlashMLA
|
# - "FLASHMLA": use FlashMLA
|
||||||
# - "FLASH_ATTN_MLA": use FlashAttention for MLA
|
# - "FLASH_ATTN_MLA": use FlashAttention for MLA
|
||||||
|
# - "FLASHINFER_MLA": use FlashInfer for MLA
|
||||||
|
# - "CUTLASS_MLA": use CUTLASS for MLA
|
||||||
"VLLM_ATTENTION_BACKEND":
|
"VLLM_ATTENTION_BACKEND":
|
||||||
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
|
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
|
from functools import cached_property
|
||||||
from multiprocessing import Lock
|
from multiprocessing import Lock
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@ -17,6 +18,7 @@ from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
|||||||
run_method)
|
run_method)
|
||||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||||
from vllm.v1.executor.utils import get_and_update_mm_cache
|
from vllm.v1.executor.utils import get_and_update_mm_cache
|
||||||
|
from vllm.v1.outputs import AsyncModelRunnerOutput
|
||||||
from vllm.worker.worker_base import WorkerWrapperBase
|
from vllm.worker.worker_base import WorkerWrapperBase
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -31,15 +33,7 @@ class UniProcExecutor(ExecutorBase):
|
|||||||
"""
|
"""
|
||||||
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
|
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
|
||||||
rpc_rank=0)
|
rpc_rank=0)
|
||||||
distributed_init_method = get_distributed_init_method(
|
distributed_init_method, rank, local_rank = self._distributed_args()
|
||||||
get_ip(), get_open_port())
|
|
||||||
local_rank = 0
|
|
||||||
# set local rank as the device index if specified
|
|
||||||
device_info = self.vllm_config.device_config.device.__str__().split(
|
|
||||||
":")
|
|
||||||
if len(device_info) > 1:
|
|
||||||
local_rank = int(device_info[1])
|
|
||||||
rank = 0
|
|
||||||
is_driver_worker = True
|
is_driver_worker = True
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
vllm_config=self.vllm_config,
|
vllm_config=self.vllm_config,
|
||||||
@ -50,21 +44,56 @@ class UniProcExecutor(ExecutorBase):
|
|||||||
)
|
)
|
||||||
self.mm_receiver_cache = worker_receiver_cache_from_config(
|
self.mm_receiver_cache = worker_receiver_cache_from_config(
|
||||||
self.vllm_config, MULTIMODAL_REGISTRY, Lock())
|
self.vllm_config, MULTIMODAL_REGISTRY, Lock())
|
||||||
|
|
||||||
|
self.async_output_thread: Optional[ThreadPoolExecutor] = None
|
||||||
|
if self.max_concurrent_batches > 1:
|
||||||
|
self.async_output_thread = ThreadPoolExecutor(
|
||||||
|
max_workers=1, thread_name_prefix="WorkerAsyncOutput")
|
||||||
|
|
||||||
self.collective_rpc("init_worker", args=([kwargs], ))
|
self.collective_rpc("init_worker", args=([kwargs], ))
|
||||||
self.collective_rpc("init_device")
|
self.collective_rpc("init_device")
|
||||||
self.collective_rpc("load_model")
|
self.collective_rpc("load_model")
|
||||||
|
|
||||||
|
def _distributed_args(self) -> tuple[str, int, int]:
|
||||||
|
"""Return (distributed_init_method, rank, local_rank)."""
|
||||||
|
distributed_init_method = get_distributed_init_method(
|
||||||
|
get_ip(), get_open_port())
|
||||||
|
# set local rank as the device index if specified
|
||||||
|
device_info = self.vllm_config.device_config.device.__str__().split(
|
||||||
|
":")
|
||||||
|
local_rank = int(device_info[1]) if len(device_info) > 1 else 0
|
||||||
|
return distributed_init_method, 0, local_rank
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def max_concurrent_batches(self) -> int:
|
||||||
|
return 2 if self.scheduler_config.async_scheduling else 1
|
||||||
|
|
||||||
def collective_rpc(self,
|
def collective_rpc(self,
|
||||||
method: Union[str, Callable],
|
method: Union[str, Callable],
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
args: Tuple = (),
|
args: Tuple = (),
|
||||||
kwargs: Optional[Dict] = None) -> List[Any]:
|
kwargs: Optional[Dict] = None,
|
||||||
|
non_block: bool = False) -> List[Any]:
|
||||||
if kwargs is None:
|
if kwargs is None:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if self.mm_receiver_cache is not None and method == "execute_model":
|
if self.mm_receiver_cache is not None and method == "execute_model":
|
||||||
get_and_update_mm_cache(self.mm_receiver_cache, args)
|
get_and_update_mm_cache(self.mm_receiver_cache, args)
|
||||||
answer = run_method(self.driver_worker, method, args, kwargs)
|
|
||||||
return [answer]
|
if not non_block:
|
||||||
|
return [run_method(self.driver_worker, method, args, kwargs)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = run_method(self.driver_worker, method, args, kwargs)
|
||||||
|
if isinstance(result, AsyncModelRunnerOutput):
|
||||||
|
if (async_thread := self.async_output_thread) is not None:
|
||||||
|
return [async_thread.submit(result.get_output)]
|
||||||
|
result = result.get_output()
|
||||||
|
future = Future[Any]()
|
||||||
|
future.set_result(result)
|
||||||
|
except Exception as e:
|
||||||
|
future = Future[Any]()
|
||||||
|
future.set_exception(e)
|
||||||
|
return [future]
|
||||||
|
|
||||||
def check_health(self) -> None:
|
def check_health(self) -> None:
|
||||||
# UniProcExecutor will always be healthy as long as
|
# UniProcExecutor will always be healthy as long as
|
||||||
@ -116,8 +145,9 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
|
|||||||
assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \
|
assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \
|
||||||
("To get deterministic execution in V1, "
|
("To get deterministic execution in V1, "
|
||||||
"please set VLLM_ENABLE_V1_MULTIPROCESSING=0")
|
"please set VLLM_ENABLE_V1_MULTIPROCESSING=0")
|
||||||
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
|
super()._init_executor()
|
||||||
rpc_rank=0)
|
|
||||||
|
def _distributed_args(self) -> tuple[str, int, int]:
|
||||||
# engines are launched in torchrun-compatible launchers
|
# engines are launched in torchrun-compatible launchers
|
||||||
# so we can use the env:// method.
|
# so we can use the env:// method.
|
||||||
# required env vars:
|
# required env vars:
|
||||||
@ -128,19 +158,7 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
|
|||||||
distributed_init_method = "env://"
|
distributed_init_method = "env://"
|
||||||
rank = int(os.environ["RANK"])
|
rank = int(os.environ["RANK"])
|
||||||
local_rank = int(os.environ["LOCAL_RANK"])
|
local_rank = int(os.environ["LOCAL_RANK"])
|
||||||
is_driver_worker = True
|
return distributed_init_method, rank, local_rank
|
||||||
kwargs = dict(
|
|
||||||
vllm_config=self.vllm_config,
|
|
||||||
local_rank=local_rank,
|
|
||||||
rank=rank,
|
|
||||||
distributed_init_method=distributed_init_method,
|
|
||||||
is_driver_worker=is_driver_worker,
|
|
||||||
)
|
|
||||||
self.mm_receiver_cache = worker_receiver_cache_from_config(
|
|
||||||
self.vllm_config, MULTIMODAL_REGISTRY, Lock())
|
|
||||||
self.collective_rpc("init_worker", args=([kwargs], ))
|
|
||||||
self.collective_rpc("init_device")
|
|
||||||
self.collective_rpc("load_model")
|
|
||||||
|
|
||||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from math import log2
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -10,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
|||||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||||
TopKWeightAndReduceDelegate)
|
TopKWeightAndReduceDelegate)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
|
from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
|
||||||
is_deep_gemm_e8m0_used)
|
is_deep_gemm_e8m0_used)
|
||||||
@ -24,35 +26,28 @@ def _silu_mul_fp8_quant_deep_gemm(
|
|||||||
y_q_ptr, # fp8 quantized activations (E, T, H)
|
y_q_ptr, # fp8 quantized activations (E, T, H)
|
||||||
y_s_ptr, # 16-bit scales (E, T, G)
|
y_s_ptr, # 16-bit scales (E, T, G)
|
||||||
counts_ptr, # int32 num tokens per expert (E)
|
counts_ptr, # int32 num tokens per expert (E)
|
||||||
|
|
||||||
# Sizes ---------------------------------------------------------------
|
# Sizes ---------------------------------------------------------------
|
||||||
H: tl.constexpr, # hidden dimension (per output)
|
H: tl.constexpr, # hidden dimension (per output)
|
||||||
GROUP_SIZE: tl.constexpr, # elements per group (usually 128)
|
GROUP_SIZE: tl.constexpr, # elements per group (usually 128)
|
||||||
|
|
||||||
# Strides for input (elements) ---------------------------------------
|
# Strides for input (elements) ---------------------------------------
|
||||||
stride_i_e,
|
stride_i_e,
|
||||||
stride_i_t,
|
stride_i_t,
|
||||||
stride_i_h,
|
stride_i_h,
|
||||||
|
|
||||||
# Strides for y_q (elements) -----------------------------------------
|
# Strides for y_q (elements) -----------------------------------------
|
||||||
stride_yq_e,
|
stride_yq_e,
|
||||||
stride_yq_t,
|
stride_yq_t,
|
||||||
stride_yq_h,
|
stride_yq_h,
|
||||||
|
|
||||||
# Strides for y_s (elements) -----------------------------------------
|
# Strides for y_s (elements) -----------------------------------------
|
||||||
stride_ys_e,
|
stride_ys_e,
|
||||||
stride_ys_t,
|
stride_ys_t,
|
||||||
stride_ys_g,
|
stride_ys_g,
|
||||||
|
|
||||||
# Stride for counts (elements)
|
# Stride for counts (elements)
|
||||||
stride_counts_e,
|
stride_counts_e,
|
||||||
|
|
||||||
# Numeric params ------------------------------------------------------
|
# Numeric params ------------------------------------------------------
|
||||||
eps: tl.constexpr,
|
eps: tl.constexpr,
|
||||||
fp8_min: tl.constexpr,
|
fp8_min: tl.constexpr,
|
||||||
fp8_max: tl.constexpr,
|
fp8_max: tl.constexpr,
|
||||||
use_ue8m0: tl.constexpr,
|
use_ue8m0: tl.constexpr,
|
||||||
|
|
||||||
# Meta ---------------------------------------------------------------
|
# Meta ---------------------------------------------------------------
|
||||||
BLOCK: tl.constexpr,
|
BLOCK: tl.constexpr,
|
||||||
NUM_STAGES: tl.constexpr,
|
NUM_STAGES: tl.constexpr,
|
||||||
@ -101,17 +96,15 @@ def _silu_mul_fp8_quant_deep_gemm(
|
|||||||
tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
|
tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
|
||||||
|
|
||||||
|
|
||||||
def silu_mul_fp8_quant_deep_gemm(
|
def silu_mul_fp8_quant_deep_gemm_cuda(
|
||||||
y: torch.Tensor, # (E, T, 2*H)
|
y: torch.Tensor, # (E, T, 2*H)
|
||||||
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
|
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
|
||||||
|
num_parallel_tokens=16,
|
||||||
group_size: int = 128,
|
group_size: int = 128,
|
||||||
eps: float = 1e-10,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
|
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
|
||||||
|
y has shape (E, T, 2*H). The first half of the last dimension is
|
||||||
y has shape (E, T, 2*H). The first half of the last dimension is
|
|
||||||
silu-activated, multiplied by the second half, then quantized into FP8.
|
silu-activated, multiplied by the second half, then quantized into FP8.
|
||||||
|
|
||||||
Returns `(y_q, y_s)` where
|
Returns `(y_q, y_s)` where
|
||||||
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
|
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
|
||||||
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
|
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
|
||||||
@ -120,22 +113,17 @@ def silu_mul_fp8_quant_deep_gemm(
|
|||||||
E, T, H2 = y.shape
|
E, T, H2 = y.shape
|
||||||
assert H2 % 2 == 0, "last dim of y must be even (2*H)"
|
assert H2 % 2 == 0, "last dim of y must be even (2*H)"
|
||||||
H = H2 // 2
|
H = H2 // 2
|
||||||
G = H // group_size
|
G = (H + group_size - 1) // group_size
|
||||||
assert H % group_size == 0, "H must be divisible by group_size"
|
assert H % 8 == 0, "H must be divisible by 8"
|
||||||
assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \
|
assert group_size == 128, "H must be divisible by 8"
|
||||||
"tokens_per_expert must be shape (E,)"
|
assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E
|
||||||
|
|
||||||
tokens_per_expert = tokens_per_expert.to(device=y.device,
|
tokens_per_expert = tokens_per_expert.to(device=y.device,
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
|
|
||||||
# allocate outputs
|
|
||||||
fp8_dtype = torch.float8_e4m3fn
|
fp8_dtype = torch.float8_e4m3fn
|
||||||
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
|
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
|
||||||
|
|
||||||
# strides (elements)
|
|
||||||
stride_i_e, stride_i_t, stride_i_h = y.stride()
|
|
||||||
stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
|
|
||||||
|
|
||||||
# desired scale strides (elements): (T*G, 1, T)
|
|
||||||
stride_ys_e = T * G
|
stride_ys_e = T * G
|
||||||
stride_ys_t = 1
|
stride_ys_t = 1
|
||||||
stride_ys_g = T
|
stride_ys_g = T
|
||||||
@ -144,47 +132,86 @@ def silu_mul_fp8_quant_deep_gemm(
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=y.device)
|
device=y.device)
|
||||||
|
|
||||||
stride_cnt_e = tokens_per_expert.stride()[0]
|
use_ue8m0 = is_deep_gemm_e8m0_used()
|
||||||
|
|
||||||
# Static grid over experts and H-groups.
|
if E <= 16:
|
||||||
# A loop inside the kernel handles the token dim
|
max_empirical_parallelism = 64
|
||||||
grid = (E * G, )
|
elif E <= 32:
|
||||||
|
max_empirical_parallelism = 16
|
||||||
|
else:
|
||||||
|
max_empirical_parallelism = 4
|
||||||
|
|
||||||
f_info = torch.finfo(fp8_dtype)
|
# We never want to launch more than Tx number of threads
|
||||||
fp8_max = f_info.max
|
# This computes the clip.
|
||||||
fp8_min = f_info.min
|
num_parallel_tokens = max(
|
||||||
|
1,
|
||||||
|
min(max_empirical_parallelism, 2**int(log2(min(num_parallel_tokens,
|
||||||
|
T)))))
|
||||||
|
cuda_arch = current_platform.get_device_capability(
|
||||||
|
device_id=y.device.index).to_int()
|
||||||
|
|
||||||
_silu_mul_fp8_quant_deep_gemm[grid](
|
if cuda_arch >= 80:
|
||||||
y,
|
torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda(y, tokens_per_expert,
|
||||||
y_q,
|
y_q, y_s, group_size,
|
||||||
y_s,
|
use_ue8m0,
|
||||||
tokens_per_expert,
|
num_parallel_tokens)
|
||||||
H,
|
else:
|
||||||
group_size,
|
# Default to triton if not on cuda or if arch is too old
|
||||||
stride_i_e,
|
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
|
||||||
stride_i_t,
|
|
||||||
stride_i_h,
|
stride_cnt_e = tokens_per_expert.stride()[0]
|
||||||
stride_yq_e,
|
|
||||||
stride_yq_t,
|
# Static grid over experts and H-groups.
|
||||||
stride_yq_h,
|
# A loop inside the kernel handles the token dim
|
||||||
stride_ys_e,
|
grid = (E * G, )
|
||||||
stride_ys_t,
|
# strides (elements)
|
||||||
stride_ys_g,
|
stride_i_e, stride_i_t, stride_i_h = y.stride()
|
||||||
stride_cnt_e,
|
stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
|
||||||
eps,
|
|
||||||
fp8_min,
|
# desired scale strides (elements): (T*G, 1, T)
|
||||||
fp8_max,
|
stride_ys_e = T * G
|
||||||
is_deep_gemm_e8m0_used(),
|
stride_ys_t = 1
|
||||||
BLOCK=group_size,
|
stride_ys_g = T
|
||||||
NUM_STAGES=4,
|
y_s = torch.empty_strided(
|
||||||
num_warps=1,
|
(E, T, G),
|
||||||
)
|
(stride_ys_e, stride_ys_t, stride_ys_g),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=y.device,
|
||||||
|
)
|
||||||
|
f_info = torch.finfo(fp8_dtype)
|
||||||
|
fp8_max = f_info.max
|
||||||
|
fp8_min = f_info.min
|
||||||
|
eps: float = 1e-10
|
||||||
|
_silu_mul_fp8_quant_deep_gemm[grid](
|
||||||
|
y,
|
||||||
|
y_q,
|
||||||
|
y_s,
|
||||||
|
tokens_per_expert,
|
||||||
|
H,
|
||||||
|
group_size,
|
||||||
|
stride_i_e,
|
||||||
|
stride_i_t,
|
||||||
|
stride_i_h,
|
||||||
|
stride_yq_e,
|
||||||
|
stride_yq_t,
|
||||||
|
stride_yq_h,
|
||||||
|
stride_ys_e,
|
||||||
|
stride_ys_t,
|
||||||
|
stride_ys_g,
|
||||||
|
stride_cnt_e,
|
||||||
|
eps,
|
||||||
|
fp8_min,
|
||||||
|
fp8_max,
|
||||||
|
is_deep_gemm_e8m0_used(),
|
||||||
|
BLOCK=group_size,
|
||||||
|
NUM_STAGES=4,
|
||||||
|
num_warps=1,
|
||||||
|
)
|
||||||
|
|
||||||
return y_q, y_s
|
return y_q, y_s
|
||||||
|
|
||||||
|
|
||||||
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
|
|
||||||
# The Deep Gemm kernels only support block size of 128
|
# The Deep Gemm kernels only support block size of 128
|
||||||
DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128]
|
DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128]
|
||||||
|
|
||||||
@ -297,8 +324,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
|
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
|
||||||
workspace1, expert_num_tokens, expected_m)
|
workspace1, expert_num_tokens, expected_m)
|
||||||
|
|
||||||
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
|
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda(
|
||||||
expert_num_tokens)
|
workspace1, expert_num_tokens)
|
||||||
|
|
||||||
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output,
|
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output,
|
||||||
expert_num_tokens, expected_m)
|
expert_num_tokens, expected_m)
|
||||||
|
|||||||
@ -740,7 +740,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
"""
|
"""
|
||||||
Handle special case for models where MLP layers are already
|
Handle special case for models where MLP layers are already
|
||||||
fused on disk. In this case, we have no shard id. This function
|
fused on disk. In this case, we have no shard id. This function
|
||||||
determmines the shard id by splitting these layers and then calls
|
determines the shard id by splitting these layers and then calls
|
||||||
the weight loader using the shard id.
|
the weight loader using the shard id.
|
||||||
|
|
||||||
An example of a model with these fused layers:
|
An example of a model with these fused layers:
|
||||||
@ -914,7 +914,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
"""
|
"""
|
||||||
Handle special case for models where QKV layers are already
|
Handle special case for models where QKV layers are already
|
||||||
fused on disk. In this case, we have no shard id. This function
|
fused on disk. In this case, we have no shard id. This function
|
||||||
determmines the shard id by splitting these layers and then calls
|
determines the shard id by splitting these layers and then calls
|
||||||
the weight loader using the shard id.
|
the weight loader using the shard id.
|
||||||
|
|
||||||
An example of a model with these fused layers:
|
An example of a model with these fused layers:
|
||||||
|
|||||||
@ -88,6 +88,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
|||||||
"Setting it to k_scale. This only matters for "
|
"Setting it to k_scale. This only matters for "
|
||||||
"the flash-attn backend.")
|
"the flash-attn backend.")
|
||||||
layer._q_scale.copy_(k_scale)
|
layer._q_scale.copy_(k_scale)
|
||||||
|
layer._q_scale_float = k_scale
|
||||||
|
|
||||||
# These are used in the final Attention.forward()
|
# These are used in the final Attention.forward()
|
||||||
layer._k_scale.copy_(k_scale)
|
layer._k_scale.copy_(k_scale)
|
||||||
@ -124,6 +125,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
|||||||
|
|
||||||
# These are used in the final Attention.forward()
|
# These are used in the final Attention.forward()
|
||||||
layer._q_scale.copy_(q_scale)
|
layer._q_scale.copy_(q_scale)
|
||||||
|
layer._q_scale_float = q_scale
|
||||||
layer._prob_scale.copy_(prob_scale)
|
layer._prob_scale.copy_(prob_scale)
|
||||||
if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
|
if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
|
||||||
or prob_scale == 1.0):
|
or prob_scale == 1.0):
|
||||||
|
|||||||
@ -62,11 +62,8 @@ class RotaryEmbedding(CustomOp):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: Optional[torch.Tensor] = None,
|
key: Optional[torch.Tensor] = None,
|
||||||
offsets: Optional[torch.Tensor] = None,
|
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
"""A PyTorch-native implementation of forward()."""
|
"""A PyTorch-native implementation of forward()."""
|
||||||
if offsets is not None:
|
|
||||||
positions = positions + offsets
|
|
||||||
positions = positions.flatten()
|
positions = positions.flatten()
|
||||||
num_tokens = positions.shape[0]
|
num_tokens = positions.shape[0]
|
||||||
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
||||||
@ -96,7 +93,6 @@ class RotaryEmbedding(CustomOp):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: Optional[torch.Tensor] = None,
|
key: Optional[torch.Tensor] = None,
|
||||||
offsets: Optional[torch.Tensor] = None,
|
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
@ -107,16 +103,10 @@ class RotaryEmbedding(CustomOp):
|
|||||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
|
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
|
||||||
dtype=query.dtype)
|
dtype=query.dtype)
|
||||||
|
|
||||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
# ops.rotary_embedding() is an in-place operation
|
||||||
# are in-place operations that update the query and key tensors.
|
# that updates the query and key tensors.
|
||||||
if offsets is not None:
|
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||||
ops.batched_rotary_embedding(positions, query, key, self.head_size,
|
self.cos_sin_cache, self.is_neox_style)
|
||||||
self.cos_sin_cache,
|
|
||||||
self.is_neox_style, self.rotary_dim,
|
|
||||||
offsets)
|
|
||||||
else:
|
|
||||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
|
||||||
self.cos_sin_cache, self.is_neox_style)
|
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
def forward_xpu(
|
def forward_xpu(
|
||||||
@ -124,29 +114,21 @@ class RotaryEmbedding(CustomOp):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: Optional[torch.Tensor] = None,
|
key: Optional[torch.Tensor] = None,
|
||||||
offsets: Optional[torch.Tensor] = None,
|
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
from vllm._ipex_ops import ipex_ops as ops
|
from vllm._ipex_ops import ipex_ops as ops
|
||||||
|
|
||||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
|
self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
|
||||||
dtype=query.dtype)
|
dtype=query.dtype)
|
||||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
# ops.rotary_embedding() is an in-place operation
|
||||||
# are in-place operations that update the query and key tensors.
|
# that updates the query and key tensors.
|
||||||
if key is None:
|
if key is None:
|
||||||
# XPU kernel doesn't support key=None so fall back to native impl
|
# XPU kernel doesn't support key=None so fall back to native impl
|
||||||
# TODO(sarckk): add support for optional key in
|
# TODO(sarckk): add support for optional key in
|
||||||
# ipex.llm.functional.rotary_embedding_batched
|
# ipex.llm.functional.rotary_embedding_batched
|
||||||
return self.forward_native(positions, query, key, offsets)
|
return self.forward_native(positions, query, key)
|
||||||
else:
|
else:
|
||||||
if offsets is not None:
|
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||||
ops.batched_rotary_embedding(positions, query, key,
|
self.cos_sin_cache, self.is_neox_style)
|
||||||
self.head_size,
|
|
||||||
self.cos_sin_cache,
|
|
||||||
self.is_neox_style,
|
|
||||||
self.rotary_dim, offsets)
|
|
||||||
else:
|
|
||||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
|
||||||
self.cos_sin_cache, self.is_neox_style)
|
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
|
|||||||
@ -258,7 +258,7 @@ class VocabParallelEmbedding(CustomOp):
|
|||||||
|
|
||||||
if params_dtype is None:
|
if params_dtype is None:
|
||||||
params_dtype = torch.get_default_dtype()
|
params_dtype = torch.get_default_dtype()
|
||||||
# Divide the weight matrix along the vocaburaly dimension.
|
# Divide the weight matrix along the vocabulary dimension.
|
||||||
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
|
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
|
||||||
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
|
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
|
||||||
self.tp_size)
|
self.tp_size)
|
||||||
|
|||||||
@ -1446,7 +1446,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# The result multimodal_embeddings is tuple of tensors, with each
|
# The result multimodal_embeddings is tuple of tensors, with each
|
||||||
# tensor correspoending to a multimodal data item (image or video).
|
# tensor corresponding to a multimodal data item (image or video).
|
||||||
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
||||||
|
|
||||||
# NOTE: It is important to iterate over the keys in this dictionary
|
# NOTE: It is important to iterate over the keys in this dictionary
|
||||||
|
|||||||
@ -586,10 +586,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
|
|
||||||
# ruff: noqa
|
# ruff: noqa
|
||||||
# The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
|
# The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
|
||||||
# text to account for this. However, the audio preprocessing and encoder do not gurarantee they will
|
# text to account for this. However, the audio preprocessing and encoder do not guarantee they will
|
||||||
# produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
|
# produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
|
||||||
# depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
|
# depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
|
||||||
# the audio feature out to 188 soft tokens with the emebedding of the last token in the embed_audio vocab.
|
# the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
|
||||||
# TODO precompute and cache padding
|
# TODO precompute and cache padding
|
||||||
audio_padding_toks = torch.tensor([[self.vocab_size - 1]],
|
audio_padding_toks = torch.tensor([[self.vocab_size - 1]],
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
|
|||||||
@ -823,7 +823,7 @@ class SupportsEagle3(Protocol):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
layers: Tuple of layer indices that should output auxiliary
|
layers: Tuple of layer indices that should output auxiliary
|
||||||
hidden states.
|
hidden states.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|||||||
@ -1520,15 +1520,9 @@ class BaseKeyeModule(nn.Module):
|
|||||||
batch.
|
batch.
|
||||||
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
|
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
|
||||||
opensource models), the shape will be `(3, seq_len)`,
|
opensource models), the shape will be `(3, seq_len)`,
|
||||||
otherwise it will be `(seq_len,).
|
otherwise it will be `(seq_len,)`.
|
||||||
pixel_values: Pixel values to be fed to a model.
|
intermediate_tensors: Intermediate tensors from prior forward pass.
|
||||||
`None` if no images are passed.
|
inputs_embeds: Optional tensor of input embeddings.
|
||||||
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
|
|
||||||
`None` if no images are passed.
|
|
||||||
pixel_values_videos: Pixel values of videos to be fed to a model.
|
|
||||||
`None` if no videos are passed.
|
|
||||||
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
|
|
||||||
`None` if no videos are passed.
|
|
||||||
"""
|
"""
|
||||||
if intermediate_tensors is not None:
|
if intermediate_tensors is not None:
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
|
|||||||
@ -58,17 +58,18 @@ def split_thw(grid_thw: torch.Tensor) -> torch.Tensor:
|
|||||||
return torch.cat([ones, h_w], dim=1).repeat_interleave(t, dim=0)
|
return torch.cat([ones, h_w], dim=1).repeat_interleave(t, dim=0)
|
||||||
|
|
||||||
|
|
||||||
def get_num_patches(grid_thw: torch.Tensor, num_frames: Union[list[int],
|
def get_num_patches(grid_thw: torch.Tensor,
|
||||||
torch.Tensor]):
|
num_frames: Union[list[int], torch.Tensor]) -> list[int]:
|
||||||
"""
|
"""
|
||||||
Return num_patches per video.
|
Return num_patches per video.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
t: tensor with shape [N, ...] where each item is a list/tensor
|
grid_thw: Tensor with shape [N, 3] containing temporal, height, width
|
||||||
cu_seqlens: list indicating the boundaries of groups
|
dimensions
|
||||||
|
num_frames: List or tensor indicating the number of frames per video
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list of ints representing the sum of products for each group
|
List of ints representing the number of patches for each video
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> # Suppose there are 2 videos with a total of 3 grids
|
>>> # Suppose there are 2 videos with a total of 3 grids
|
||||||
|
|||||||
@ -732,7 +732,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
Args:
|
Args:
|
||||||
input_ids: Flattened (concatenated) input_ids corresponding to a
|
input_ids: Flattened (concatenated) input_ids corresponding to a
|
||||||
batch.
|
batch.
|
||||||
pixel_values: The pixels in each input image.
|
positions: Position indices for the input tokens.
|
||||||
|
intermediate_tensors: Intermediate tensors from prior forward pass.
|
||||||
|
inputs_embeds: Optional tensor of input embeddings.
|
||||||
|
|
||||||
Info:
|
Info:
|
||||||
[LlavaImageInputs][]
|
[LlavaImageInputs][]
|
||||||
|
|||||||
@ -535,8 +535,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
Args:
|
Args:
|
||||||
input_ids: Flattened (concatenated) input_ids corresponding to a
|
input_ids: Flattened (concatenated) input_ids corresponding to a
|
||||||
batch.
|
batch.
|
||||||
pixel_values: The pixels in each grid patch for each input image.
|
positions: Position indices for the input tokens.
|
||||||
image_sizes: The original `(height, width)` for each input image.
|
intermediate_tensors: Intermediate tensors from prior forward pass.
|
||||||
|
inputs_embeds: Optional tensor of input embeddings.
|
||||||
|
|
||||||
Info:
|
Info:
|
||||||
[LlavaNextImageInputs][]
|
[LlavaNextImageInputs][]
|
||||||
|
|||||||
@ -578,7 +578,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
|
|||||||
Args:
|
Args:
|
||||||
input_ids: Flattened (concatenated) input_ids corresponding to a
|
input_ids: Flattened (concatenated) input_ids corresponding to a
|
||||||
batch.
|
batch.
|
||||||
pixel_values: The pixels in each input image.
|
positions: Position indices for the input tokens.
|
||||||
|
intermediate_tensors: Intermediate tensors from prior forward pass.
|
||||||
|
inputs_embeds: Optional tensor of input embeddings.
|
||||||
|
|
||||||
Info:
|
Info:
|
||||||
[Mistral3ImagePixelInputs][]
|
[Mistral3ImagePixelInputs][]
|
||||||
|
|||||||
@ -387,11 +387,10 @@ class Llama4VisionEncoder(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
inputs_embeds (`torch.FloatTensor` of shape
|
hidden_states: Input tensor of shape
|
||||||
`(batch_size, sequence_length, hidden_size)`):
|
(batch_size, sequence_length, hidden_size).
|
||||||
Optionally, instead of passing `input_ids` you can choose to
|
Hidden states from the model embeddings, representing
|
||||||
directly pass an embedded representation. This is useful if you
|
the input tokens.
|
||||||
want more control over how to convert `input_ids` indices into
|
|
||||||
associated vectors than the model's internal embedding
|
associated vectors than the model's internal embedding
|
||||||
lookup matrix.
|
lookup matrix.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -70,11 +70,15 @@ def multihead_attention(
|
|||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
q_cu_seqlens: Optional[torch.Tensor] = None,
|
q_cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
k_cu_seqlens: Optional[torch.Tensor] = None,
|
k_cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
):
|
) -> torch.Tensor:
|
||||||
"""Multi-head attention using flash attention 2.
|
"""Multi-head attention using flash attention 2.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
||||||
|
or (tot_seqlens, num_heads, head_dim) if packing.
|
||||||
|
k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
||||||
|
or (tot_seqlens, num_heads, head_dim) if packing.
|
||||||
|
v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
||||||
or (tot_seqlens, num_heads, head_dim) if packing.
|
or (tot_seqlens, num_heads, head_dim) if packing.
|
||||||
q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
|
q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
|
||||||
The first element should be 0 and the last element should be q.shape[0].
|
The first element should be 0 and the last element should be q.shape[0].
|
||||||
@ -123,8 +127,14 @@ def sdpa_attention(
|
|||||||
"""SDPA attention.
|
"""SDPA attention.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
||||||
or (tot_seqlens, num_heads, head_dim) if packing.
|
or (tot_seqlens, num_heads, head_dim) if packing.
|
||||||
|
k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
||||||
|
or (tot_seqlens, num_heads, head_dim) if packing.
|
||||||
|
v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
|
||||||
|
or (tot_seqlens, num_heads, head_dim) if packing.
|
||||||
|
q_cu_seqlens: Optional cumulative sequence lengths of q.
|
||||||
|
k_cu_seqlens: Optional cumulative sequence lengths of k.
|
||||||
"""
|
"""
|
||||||
seq_length = q.shape[0]
|
seq_length = q.shape[0]
|
||||||
attention_mask = torch.zeros([1, seq_length, seq_length],
|
attention_mask = torch.zeros([1, seq_length, seq_length],
|
||||||
@ -387,7 +397,7 @@ class MLP2(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
dims: list[int],
|
dims: list[int],
|
||||||
activation,
|
activation,
|
||||||
bias=True,
|
bias: bool = True,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
use_data_parallel: bool = False):
|
use_data_parallel: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@ -560,7 +560,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
# The result multimodal_embeddings is tuple of tensors, with each
|
# The result multimodal_embeddings is tuple of tensors, with each
|
||||||
# tensor correspoending to a multimodal data item (image).
|
# tensor corresponding to a multimodal data item (image).
|
||||||
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
||||||
|
|
||||||
# NOTE: It is important to iterate over the keys in this dictionary
|
# NOTE: It is important to iterate over the keys in this dictionary
|
||||||
|
|||||||
@ -52,10 +52,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
||||||
from vllm.model_executor.models.utils import (
|
from vllm.model_executor.models.utils import (
|
||||||
AutoWeightsLoader, is_pp_missing_parameter,
|
AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.transformers_utils.configs import Olmo3Config
|
||||||
|
|
||||||
|
|
||||||
class Olmo2Attention(nn.Module):
|
class Olmo2Attention(nn.Module):
|
||||||
@ -68,7 +69,7 @@ class Olmo2Attention(nn.Module):
|
|||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = vllm_config.model_config.hf_config
|
self.config = vllm_config.model_config.hf_config
|
||||||
assert isinstance(self.config, Olmo2Config)
|
assert isinstance(self.config, (Olmo2Config, Olmo3Config))
|
||||||
|
|
||||||
hidden_size = self.config.hidden_size
|
hidden_size = self.config.hidden_size
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
@ -111,14 +112,14 @@ class Olmo2Attention(nn.Module):
|
|||||||
self.q_norm = RMSNorm(self.config.hidden_size,
|
self.q_norm = RMSNorm(self.config.hidden_size,
|
||||||
eps=self.config.rms_norm_eps)
|
eps=self.config.rms_norm_eps)
|
||||||
|
|
||||||
# Rotary embeddings.
|
|
||||||
self.rotary_emb = get_rope(
|
|
||||||
self.head_dim,
|
|
||||||
rotary_dim=self.head_dim,
|
|
||||||
max_position=self.max_position_embeddings,
|
|
||||||
base=self.rope_theta, # type: ignore
|
|
||||||
)
|
|
||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
|
|
||||||
|
layer_idx = extract_layer_index(prefix)
|
||||||
|
sliding_window = None
|
||||||
|
if ((layer_types := getattr(self.config, "layer_types", None))
|
||||||
|
is not None and layer_types[layer_idx] == "sliding_attention"):
|
||||||
|
sliding_window = self.config.sliding_window
|
||||||
|
|
||||||
self.attn = Attention(
|
self.attn = Attention(
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@ -126,7 +127,20 @@ class Olmo2Attention(nn.Module):
|
|||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=vllm_config.cache_config,
|
cache_config=vllm_config.cache_config,
|
||||||
quant_config=vllm_config.quant_config,
|
quant_config=vllm_config.quant_config,
|
||||||
prefix=prefix,
|
per_layer_sliding_window=sliding_window,
|
||||||
|
prefix=f"{prefix}.attn",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rotary embeddings. Rope scaling is only applied on full attention
|
||||||
|
# layers.
|
||||||
|
self.rope_scaling = (self.config.rope_scaling
|
||||||
|
if sliding_window is None else None)
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
self.head_dim,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
max_position=self.max_position_embeddings,
|
||||||
|
base=self.rope_theta, # type: ignore
|
||||||
|
rope_scaling=self.rope_scaling,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Attention output projection.
|
# Attention output projection.
|
||||||
@ -176,7 +190,7 @@ class Olmo2MLP(nn.Module):
|
|||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
assert isinstance(config, Olmo2Config)
|
assert isinstance(config, (Olmo2Config, Olmo3Config))
|
||||||
hidden_size = config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
intermediate_size = config.intermediate_size
|
intermediate_size = config.intermediate_size
|
||||||
|
|
||||||
@ -221,7 +235,7 @@ class Olmo2DecoderLayer(nn.Module):
|
|||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
assert isinstance(config, Olmo2Config)
|
assert isinstance(config, (Olmo2Config, Olmo3Config))
|
||||||
# Attention block.
|
# Attention block.
|
||||||
self.self_attn = Olmo2Attention(vllm_config=vllm_config,
|
self.self_attn = Olmo2Attention(vllm_config=vllm_config,
|
||||||
prefix=f"{prefix}.self_attn")
|
prefix=f"{prefix}.self_attn")
|
||||||
@ -261,7 +275,7 @@ class Olmo2Model(nn.Module):
|
|||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = vllm_config.model_config.hf_config
|
self.config = vllm_config.model_config.hf_config
|
||||||
assert isinstance(self.config, Olmo2Config)
|
assert isinstance(self.config, (Olmo2Config, Olmo3Config))
|
||||||
|
|
||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
self.config.vocab_size,
|
self.config.vocab_size,
|
||||||
@ -376,7 +390,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
|||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
assert isinstance(config, Olmo2Config)
|
assert isinstance(config, (Olmo2Config, Olmo3Config))
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model = Olmo2Model(vllm_config=vllm_config,
|
self.model = Olmo2Model(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "model"))
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
|
|||||||
@ -374,8 +374,8 @@ class Phi4MMAudioMeanVarianceNormLayer(nn.Module):
|
|||||||
Typically used as a very first layer in a model.
|
Typically used as a very first layer in a model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_size: int
|
config: [Phi4MultimodalAudioConfig](https://huggingface.co/docs/transformers/model_doc/phi4_multimodal#transformers.Phi4MultimodalAudioConfig)
|
||||||
layer input size.
|
object containing model parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: Phi4MultimodalAudioConfig):
|
def __init__(self, config: Phi4MultimodalAudioConfig):
|
||||||
|
|||||||
@ -1154,7 +1154,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# The result multimodal_embeddings is tuple of tensors, with each
|
# The result multimodal_embeddings is tuple of tensors, with each
|
||||||
# tensor correspoending to a multimodal data item (image or video).
|
# tensor corresponding to a multimodal data item (image or video).
|
||||||
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
||||||
|
|
||||||
# NOTE: It is important to iterate over the keys in this dictionary
|
# NOTE: It is important to iterate over the keys in this dictionary
|
||||||
|
|||||||
@ -1372,15 +1372,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
batch.
|
batch.
|
||||||
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
|
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
|
||||||
opensource models), the shape will be `(3, seq_len)`,
|
opensource models), the shape will be `(3, seq_len)`,
|
||||||
otherwise it will be `(seq_len,).
|
otherwise it will be `(seq_len,)`.
|
||||||
pixel_values: Pixel values to be fed to a model.
|
intermediate_tensors: Intermediate tensors from prior forward pass.
|
||||||
`None` if no images are passed.
|
inputs_embeds: Optional tensor of input embeddings.
|
||||||
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
|
|
||||||
`None` if no images are passed.
|
|
||||||
pixel_values_videos: Pixel values of videos to be fed to a model.
|
|
||||||
`None` if no videos are passed.
|
|
||||||
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
|
|
||||||
`None` if no videos are passed.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if intermediate_tensors is not None:
|
if intermediate_tensors is not None:
|
||||||
|
|||||||
@ -170,8 +170,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
return quant_config
|
return quant_config
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# NOTE: hidden_states can have either 1D or 2D shape.
|
assert hidden_states.dim(
|
||||||
orig_shape = hidden_states.shape
|
) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
|
||||||
|
is_input_1d = hidden_states.dim() == 1
|
||||||
hidden_dim = hidden_states.shape[-1]
|
hidden_dim = hidden_states.shape[-1]
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
|
|
||||||
@ -180,7 +181,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||||
router_logits=router_logits)
|
router_logits=router_logits)
|
||||||
|
|
||||||
return final_hidden_states.view(orig_shape)
|
# return to 1d if input is 1d
|
||||||
|
return final_hidden_states.squeeze(0) if is_input_1d else \
|
||||||
|
final_hidden_states
|
||||||
|
|
||||||
|
|
||||||
class Qwen3MoeAttention(nn.Module):
|
class Qwen3MoeAttention(nn.Module):
|
||||||
|
|||||||
@ -120,6 +120,7 @@ _TEXT_GENERATION_MODELS = {
|
|||||||
"NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
|
"NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
|
||||||
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
|
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
|
||||||
"Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
|
"Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
|
||||||
|
"Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
|
||||||
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
|
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
|
||||||
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
|
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
|
||||||
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
|
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
|
||||||
|
|||||||
@ -390,12 +390,9 @@ class Siglip2EncoderLayer(nn.Module):
|
|||||||
position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]:
|
position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (`torch.FloatTensor`):
|
hidden_states: Input tensor of shape (batch, seq_len, embed_dim).
|
||||||
Input to the layer of shape `(batch, seq_len, embed_dim)`.
|
cu_seqlens: Cumulative sequence lengths tensor.
|
||||||
output_attentions (`bool`, *optional*, defaults to `False`):
|
position_embeddings: Position embeddings tensor.
|
||||||
Whether or not to return the attentions tensors of all
|
|
||||||
attention layers. See `attentions` under
|
|
||||||
returned tensors for more detail.
|
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
@ -534,19 +531,11 @@ class Siglip2Encoder(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
inputs_embeds (`torch.FloatTensor` of shape
|
inputs_embeds: Input tensor of shape
|
||||||
`(batch_size, sequence_length, hidden_size)`):
|
(batch_size, sequence_length, hidden_size).
|
||||||
Optionally, instead of passing `input_ids` you can choose to
|
Embedded representation of the input tokens.
|
||||||
directly pass an embedded representation. This is useful if
|
grid_thws: Grid tensor of shape (num_patches, 3)
|
||||||
you want more control over how to convert `input_ids` indices
|
containing grid dimensions.
|
||||||
into associated vectors than the model's internal embedding
|
|
||||||
lookup matrix.
|
|
||||||
grid_thws (`torch.LongTensor`):
|
|
||||||
grid shape (num_patches, 3)
|
|
||||||
output_hidden_states (`bool`, *optional*):
|
|
||||||
Whether or not to return the hidden states of all layers. See
|
|
||||||
`hidden_states` under returned tensors for more detail.
|
|
||||||
return_dict (`bool`, *optional*):
|
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of
|
Whether or not to return a [`~utils.ModelOutput`] instead of
|
||||||
a plain tuple.
|
a plain tuple.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -597,10 +597,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
|||||||
with the `input_ids`.
|
with the `input_ids`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio_features: A batch of audio input chunks [B, N, 80, M].
|
input_ids: Flattened (concatenated) input_ids corresponding to a
|
||||||
audio_lens: Length of audio frames for each audio chunk [B].
|
batch.
|
||||||
audio_token_len: Length of audio tokens for each audio chunk [B'].
|
positions: Position indices for the input tokens.
|
||||||
Note: batch dim is different from batch dim in audio chunks.
|
intermediate_tensors: Intermediate tensors from prior forward pass.
|
||||||
|
inputs_embeds: Optional tensor of input embeddings.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@ -585,12 +585,12 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
r"language_model.model.layers.\1.mlp.down_proj"),
|
r"language_model.model.layers.\1.mlp.down_proj"),
|
||||||
(r"layers\.(\d+)\.feed_forward\.w3",
|
(r"layers\.(\d+)\.feed_forward\.w3",
|
||||||
r"language_model.model.layers.\1.mlp.up_proj"),
|
r"language_model.model.layers.\1.mlp.up_proj"),
|
||||||
(r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
|
|
||||||
r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj"
|
|
||||||
),
|
|
||||||
(r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
|
(r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
|
||||||
r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj"
|
r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj"
|
||||||
),
|
),
|
||||||
|
(r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
|
||||||
|
r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj"
|
||||||
|
),
|
||||||
(r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
|
(r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
|
||||||
r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"),
|
r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"),
|
||||||
(r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
|
(r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
|
||||||
|
|||||||
@ -909,8 +909,8 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
|||||||
prefix: Optional prefix for parameter names
|
prefix: Optional prefix for parameter names
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AssertionError: If prefix caching is enabled
|
AssertionError: If prefix caching is enabled
|
||||||
(not supported by Mamba)
|
(not supported by Mamba)
|
||||||
"""
|
"""
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
|
|||||||
@ -85,9 +85,10 @@ which are treated as audio embeddings;
|
|||||||
these are directly passed to the model without HF processing.
|
these are directly passed to the model without HF processing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ModalityData: TypeAlias = Union[_T, list[_T]]
|
ModalityData: TypeAlias = Union[_T, list[Optional[_T]], None]
|
||||||
"""
|
"""
|
||||||
Either a single data item, or a list of data items.
|
Either a single data item, or a list of data items. Can only be None if UUID
|
||||||
|
is provided.
|
||||||
|
|
||||||
The number of data items allowed per modality is restricted by
|
The number of data items allowed per modality is restricted by
|
||||||
`--limit-mm-per-prompt`.
|
`--limit-mm-per-prompt`.
|
||||||
|
|||||||
@ -36,7 +36,7 @@ class ModalityDataItems(ABC, Generic[_T, _I]):
|
|||||||
def __init__(self, data: _T, modality: str) -> None:
|
def __init__(self, data: _T, modality: str) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.data = data
|
self.data: _T = data
|
||||||
self.modality = modality
|
self.modality = modality
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@ -177,7 +177,9 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
|
|||||||
|
|
||||||
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
|
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
|
||||||
|
|
||||||
def __init__(self, data: Sequence[HfAudioItem]) -> None:
|
def __init__(self, data: Optional[Sequence[HfAudioItem]]) -> None:
|
||||||
|
if data is None:
|
||||||
|
data = [None]
|
||||||
super().__init__(data, "audio")
|
super().__init__(data, "audio")
|
||||||
|
|
||||||
def get_audio_length(self, item_idx: int) -> int:
|
def get_audio_length(self, item_idx: int) -> int:
|
||||||
@ -198,7 +200,9 @@ class ImageSize(NamedTuple):
|
|||||||
|
|
||||||
class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
|
class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
|
||||||
|
|
||||||
def __init__(self, data: Sequence[HfImageItem]) -> None:
|
def __init__(self, data: Optional[Sequence[HfImageItem]]) -> None:
|
||||||
|
if data is None:
|
||||||
|
data = [None]
|
||||||
super().__init__(data, "image")
|
super().__init__(data, "image")
|
||||||
|
|
||||||
def get_image_size(self, item_idx: int) -> ImageSize:
|
def get_image_size(self, item_idx: int) -> ImageSize:
|
||||||
@ -223,10 +227,12 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data: Sequence[HfVideoItem],
|
data: Optional[Sequence[HfVideoItem]],
|
||||||
metadata: Optional[Union[dict[str, Any],
|
metadata: Optional[Union[dict[str, Any],
|
||||||
list[Optional[dict[str, Any]]]]] = None,
|
list[Optional[dict[str, Any]]]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if data is None:
|
||||||
|
data = [None]
|
||||||
super().__init__(data, "video")
|
super().__init__(data, "video")
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
|
|
||||||
@ -385,6 +391,9 @@ class MultiModalDataParser:
|
|||||||
self,
|
self,
|
||||||
data: ModalityData[AudioItem],
|
data: ModalityData[AudioItem],
|
||||||
) -> Optional[ModalityDataItems[Any, Any]]:
|
) -> Optional[ModalityDataItems[Any, Any]]:
|
||||||
|
if data is None:
|
||||||
|
return AudioProcessorItems(None)
|
||||||
|
|
||||||
# also check single audio item with sampling rate
|
# also check single audio item with sampling rate
|
||||||
if self._is_empty(data) or (isinstance(data, tuple)
|
if self._is_empty(data) or (isinstance(data, tuple)
|
||||||
and self._is_empty(data[0])):
|
and self._is_empty(data[0])):
|
||||||
@ -420,6 +429,9 @@ class MultiModalDataParser:
|
|||||||
self,
|
self,
|
||||||
data: ModalityData[ImageItem],
|
data: ModalityData[ImageItem],
|
||||||
) -> Optional[ModalityDataItems[Any, Any]]:
|
) -> Optional[ModalityDataItems[Any, Any]]:
|
||||||
|
if data is None:
|
||||||
|
return ImageProcessorItems(None)
|
||||||
|
|
||||||
if self._is_empty(data):
|
if self._is_empty(data):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -441,6 +453,9 @@ class MultiModalDataParser:
|
|||||||
self,
|
self,
|
||||||
data: ModalityData[VideoItem],
|
data: ModalityData[VideoItem],
|
||||||
) -> Optional[ModalityDataItems[Any, Any]]:
|
) -> Optional[ModalityDataItems[Any, Any]]:
|
||||||
|
if data is None:
|
||||||
|
return VideoProcessorItems(None)
|
||||||
|
|
||||||
if self._is_empty(data):
|
if self._is_empty(data):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -1075,7 +1075,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
|||||||
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
|
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
|
||||||
"""
|
"""
|
||||||
mm_items = self.data_parser.parse_mm_data(mm_data)
|
mm_items = self.data_parser.parse_mm_data(mm_data)
|
||||||
|
|
||||||
for modality, items in mm_items.items():
|
for modality, items in mm_items.items():
|
||||||
self.validate_num_items(modality, len(items))
|
self.validate_num_items(modality, len(items))
|
||||||
|
|
||||||
@ -1436,10 +1435,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
|||||||
]
|
]
|
||||||
for modality, items_is_cached in mm_is_cached.items()
|
for modality, items_is_cached in mm_is_cached.items()
|
||||||
}
|
}
|
||||||
mm_missing_data = {
|
mm_missing_data = {}
|
||||||
modality: [mm_data_items[modality][idx] for idx in idxs]
|
for modality, idxs in mm_missing_idxs.items():
|
||||||
for modality, idxs in mm_missing_idxs.items()
|
missing_modality_data = []
|
||||||
}
|
for idx in idxs:
|
||||||
|
data = mm_data_items[modality][idx]
|
||||||
|
if data is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cache miss for {modality} at index {idx} "
|
||||||
|
f"but data is not provided.")
|
||||||
|
else:
|
||||||
|
missing_modality_data.append(data)
|
||||||
|
mm_missing_data[modality] = missing_modality_data
|
||||||
|
|
||||||
return self._to_mm_items(mm_missing_data)
|
return self._to_mm_items(mm_missing_data)
|
||||||
|
|
||||||
|
|||||||
@ -179,6 +179,7 @@ class CudaPlatformBase(Platform):
|
|||||||
cache_config.block_size = 128
|
cache_config.block_size = 128
|
||||||
logger.info("Forcing kv cache block size to 128 for "
|
logger.info("Forcing kv cache block size to 128 for "
|
||||||
"CUTLASS_MLA backend.")
|
"CUTLASS_MLA backend.")
|
||||||
|
|
||||||
if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
|
if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
|
||||||
cache_config.block_size = 64
|
cache_config.block_size = 64
|
||||||
logger.info(
|
logger.info(
|
||||||
@ -541,7 +542,9 @@ class CudaPlatformBase(Platform):
|
|||||||
attention_backend = "FLASHMLA"
|
attention_backend = "FLASHMLA"
|
||||||
|
|
||||||
# Only FlashMLA and CUTLASS_MLA support fp8
|
# Only FlashMLA and CUTLASS_MLA support fp8
|
||||||
if attention_backend in ["FLASHMLA", "CUTLASS_MLA"]:
|
if attention_backend in [
|
||||||
|
"FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"
|
||||||
|
]:
|
||||||
supported = True
|
supported = True
|
||||||
else:
|
else:
|
||||||
supported = (not fp8_attention)
|
supported = (not fp8_attention)
|
||||||
|
|||||||
@ -75,6 +75,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
|
|||||||
eagle="EAGLEConfig",
|
eagle="EAGLEConfig",
|
||||||
speculators="SpeculatorsConfig",
|
speculators="SpeculatorsConfig",
|
||||||
nemotron="NemotronConfig",
|
nemotron="NemotronConfig",
|
||||||
|
olmo3="Olmo3Config",
|
||||||
ovis="OvisConfig",
|
ovis="OvisConfig",
|
||||||
ultravox="UltravoxConfig",
|
ultravox="UltravoxConfig",
|
||||||
step3_vl="Step3VLConfig",
|
step3_vl="Step3VLConfig",
|
||||||
@ -678,20 +679,21 @@ def get_hf_file_to_dict(file_name: str,
|
|||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def get_pooling_config(model: str, revision: Optional[str] = 'main'):
|
def get_pooling_config(model: str,
|
||||||
|
revision: Optional[str] = 'main') -> Optional[dict]:
|
||||||
"""
|
"""
|
||||||
This function gets the pooling and normalize
|
This function gets the pooling and normalize
|
||||||
config from the model - only applies to
|
config from the model - only applies to
|
||||||
sentence-transformers models.
|
sentence-transformers models.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (str): The name of the Hugging Face model.
|
model: The name of the Hugging Face model.
|
||||||
revision (str, optional): The specific version
|
revision: The specific version of the model to use.
|
||||||
of the model to use. Defaults to 'main'.
|
Defaults to 'main'.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary containing the pooling
|
A dictionary containing the pooling type and whether
|
||||||
type and whether normalization is used.
|
normalization is used, or None if no pooling configuration is found.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
modules_file_name = "modules.json"
|
modules_file_name = "modules.json"
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from vllm.transformers_utils.configs.moonvit import MoonViTConfig
|
|||||||
from vllm.transformers_utils.configs.nemotron import NemotronConfig
|
from vllm.transformers_utils.configs.nemotron import NemotronConfig
|
||||||
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
|
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
|
||||||
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
|
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
|
||||||
|
from vllm.transformers_utils.configs.olmo3 import Olmo3Config
|
||||||
from vllm.transformers_utils.configs.ovis import OvisConfig
|
from vllm.transformers_utils.configs.ovis import OvisConfig
|
||||||
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
|
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
|
||||||
from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
|
from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
|
||||||
@ -45,6 +46,7 @@ __all__ = [
|
|||||||
"NemotronConfig",
|
"NemotronConfig",
|
||||||
"NemotronHConfig",
|
"NemotronHConfig",
|
||||||
"Nemotron_Nano_VL_Config",
|
"Nemotron_Nano_VL_Config",
|
||||||
|
"Olmo3Config",
|
||||||
"OvisConfig",
|
"OvisConfig",
|
||||||
"SpeculatorsConfig",
|
"SpeculatorsConfig",
|
||||||
"UltravoxConfig",
|
"UltravoxConfig",
|
||||||
|
|||||||
@ -74,10 +74,10 @@ class JAISConfig(PretrainedConfig):
|
|||||||
use_cache (`bool`, *optional*, defaults to `True`):
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
Whether or not the model should return the last key/values
|
Whether or not the model should return the last key/values
|
||||||
attentions (not used by all models).
|
attentions (not used by all models).
|
||||||
scale_attn_by_inverse_layer_idx (`bool`, *optional*,
|
scale_attn_by_inverse_layer_idx
|
||||||
defaults to `False`):
|
(`bool`, *optional*, defaults to `False`):
|
||||||
Whether to additionally scale attention weights by
|
Whether to additionally scale attention weights
|
||||||
`1 / layer_idx + 1`.
|
by `1 / layer_idx + 1`.
|
||||||
reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
|
reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to scale keys (K) prior to computing attention
|
Whether to scale keys (K) prior to computing attention
|
||||||
(dot-product)
|
(dot-product)
|
||||||
|
|||||||
80
vllm/transformers_utils/configs/olmo3.py
Normal file
80
vllm/transformers_utils/configs/olmo3.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class Olmo3Config(PretrainedConfig):
|
||||||
|
|
||||||
|
model_type = "olmo3"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=50304,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=11008,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=None,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
initializer_range=0.02,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=1,
|
||||||
|
bos_token_id=None,
|
||||||
|
eos_token_id=50279,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
rope_scaling=None,
|
||||||
|
attention_bias=False,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
rms_norm_eps=1e-5,
|
||||||
|
sliding_window=4096,
|
||||||
|
layer_types=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# This model uses Olmo3ForCausalLM in transformers but Olmo2ForCausalLM
|
||||||
|
# in vLLM.
|
||||||
|
if "architectures" not in kwargs:
|
||||||
|
kwargs["architectures"] = ["Olmo2ForCausalLM"]
|
||||||
|
elif "Olmo3ForCausalLM" in kwargs["architectures"]:
|
||||||
|
kwargs["architectures"].remove("Olmo3ForCausalLM")
|
||||||
|
kwargs["architectures"].append("Olmo2ForCausalLM")
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self.attention_bias = attention_bias
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
self.layer_types = layer_types
|
||||||
|
if self.layer_types is None:
|
||||||
|
self.layer_types = [
|
||||||
|
"sliding_attention" if (i + 1) % 4 != 0 else "full_attention"
|
||||||
|
for i in range(self.num_hidden_layers)
|
||||||
|
]
|
||||||
@ -37,10 +37,6 @@ class UltravoxConfig(transformers.PretrainedConfig):
|
|||||||
The initialization value for the layer normalization.
|
The initialization value for the layer normalization.
|
||||||
projector_act (`str`, *optional*, defaults to `"swiglu"`):
|
projector_act (`str`, *optional*, defaults to `"swiglu"`):
|
||||||
The activation function used by the multimodal projector.
|
The activation function used by the multimodal projector.
|
||||||
text_model_lora_config (`LoraConfigSimplified`, *optional*):
|
|
||||||
The LoRA configuration for finetuning the text model.
|
|
||||||
audio_model_lora_config (`LoraConfigSimplified`, *optional*):
|
|
||||||
The LoRA configuration for finetuning the audio model.
|
|
||||||
projector_ln_mid (`bool`, *optional*, defaults to `False`):
|
projector_ln_mid (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to apply layer normalization at the middle of the
|
Whether to apply layer normalization at the middle of the
|
||||||
projector or at the end. Versions v0.4.1 and below
|
projector or at the end. Versions v0.4.1 and below
|
||||||
|
|||||||
@ -25,6 +25,7 @@
|
|||||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
@ -178,17 +179,15 @@ class DeepseekVLV2Processor(ProcessorMixin):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
images: list[Image.Image],
|
images: list[Image.Image],
|
||||||
inference_mode: bool = True,
|
inference_mode: bool = True,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt (str): the formatted prompt;
|
prompt (str): the formatted prompt;
|
||||||
conversations (list[dict]): conversations with a list of messages;
|
|
||||||
images (list[ImageType]): the list of images;
|
images (list[ImageType]): the list of images;
|
||||||
inference_mode (bool): if True, then remove the last eos token;
|
inference_mode (bool): if True, then remove the last eos token;
|
||||||
system_prompt (str): the system prompt;
|
**kwargs: Additional keyword arguments.
|
||||||
**kwargs:
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
outputs (BaseProcessorOutput): the output of the processor,
|
outputs (BaseProcessorOutput): the output of the processor,
|
||||||
@ -259,7 +258,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
|
|||||||
text: str,
|
text: str,
|
||||||
images: list[Image.Image],
|
images: list[Image.Image],
|
||||||
inference_mode: bool = True,
|
inference_mode: bool = True,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@ -33,7 +33,6 @@ def list_safetensors(path: str = "") -> list[str]:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
path: The object storage path to list from.
|
path: The object storage path to list from.
|
||||||
allow_pattern: A list of patterns of which files to pull.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[str]: List of full object storage paths allowed by the pattern
|
list[str]: List of full object storage paths allowed by the pattern
|
||||||
@ -54,8 +53,7 @@ class ObjectStorageModel:
|
|||||||
dir: The temporary created directory.
|
dir: The temporary created directory.
|
||||||
|
|
||||||
Methods:
|
Methods:
|
||||||
pull_files(): Pull model from object storage to the temporary
|
pull_files(): Pull model from object storage to the temporary directory.
|
||||||
directory.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import fnmatch
|
import fnmatch
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from vllm.utils import PlaceholderModule
|
from vllm.utils import PlaceholderModule
|
||||||
|
|
||||||
@ -26,7 +26,7 @@ def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def glob(s3=None,
|
def glob(s3: Optional[Any] = None,
|
||||||
path: str = "",
|
path: str = "",
|
||||||
allow_pattern: Optional[list[str]] = None) -> list[str]:
|
allow_pattern: Optional[list[str]] = None) -> list[str]:
|
||||||
"""
|
"""
|
||||||
@ -51,7 +51,7 @@ def glob(s3=None,
|
|||||||
|
|
||||||
|
|
||||||
def list_files(
|
def list_files(
|
||||||
s3,
|
s3: Any,
|
||||||
path: str,
|
path: str,
|
||||||
allow_pattern: Optional[list[str]] = None,
|
allow_pattern: Optional[list[str]] = None,
|
||||||
ignore_pattern: Optional[list[str]] = None
|
ignore_pattern: Optional[list[str]] = None
|
||||||
|
|||||||
@ -2082,6 +2082,7 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
|
|||||||
return await task(*args, **kwargs)
|
return await task(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
def supports_kw(
|
def supports_kw(
|
||||||
callable: Callable[..., object],
|
callable: Callable[..., object],
|
||||||
kw_name: str,
|
kw_name: str,
|
||||||
|
|||||||
@ -209,7 +209,8 @@ class GDNAttentionMetadataBuilder(
|
|||||||
|
|
||||||
# prepare tensors for cudagraph
|
# prepare tensors for cudagraph
|
||||||
if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
|
if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
|
||||||
and num_spec_decodes <= self.decode_cudagraph_max_bs):
|
and num_spec_decodes <= self.decode_cudagraph_max_bs
|
||||||
|
and m.num_actual_tokens <= self.decode_cudagraph_max_bs):
|
||||||
num_total_tokens = self.vllm_config.pad_for_cudagraph(
|
num_total_tokens = self.vllm_config.pad_for_cudagraph(
|
||||||
m.num_actual_tokens)
|
m.num_actual_tokens)
|
||||||
batch_size = num_total_tokens // (self.num_spec + 1)
|
batch_size = num_total_tokens // (self.num_spec + 1)
|
||||||
|
|||||||
@ -584,7 +584,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
window_left=self._global_hyperparameters.window_left,
|
window_left=self._global_hyperparameters.window_left,
|
||||||
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
|
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
|
||||||
q_data_type=self.model_config.dtype,
|
q_data_type=self.model_config.dtype,
|
||||||
kv_data_type=self.kv_cache_spec.dtype,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare context prefills
|
# Prepare context prefills
|
||||||
@ -605,7 +604,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
logits_soft_cap=self._global_hyperparameters.
|
logits_soft_cap=self._global_hyperparameters.
|
||||||
logits_soft_cap,
|
logits_soft_cap,
|
||||||
q_data_type=self.model_config.dtype,
|
q_data_type=self.model_config.dtype,
|
||||||
kv_data_type=self.kv_cache_spec.dtype,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
prefill.prefill_main = self._fi_prefill_main
|
prefill.prefill_main = self._fi_prefill_main
|
||||||
|
|||||||
@ -6,8 +6,7 @@ from typing import Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
|
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
|
||||||
is_quantized_kv_cache)
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||||
MLACommonImpl,
|
MLACommonImpl,
|
||||||
@ -69,11 +68,9 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
"are not implemented for "
|
"are not implemented for "
|
||||||
"FlashInferMLAImpl")
|
"FlashInferMLAImpl")
|
||||||
|
|
||||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"FlashInferMLA V1 with FP8 KV cache not yet supported")
|
|
||||||
|
|
||||||
self._workspace_buffer = g_fi_workspace
|
self._workspace_buffer = g_fi_workspace
|
||||||
|
self.bmm1_scale: Optional[float] = None
|
||||||
|
self.bmm2_scale: Optional[float] = None
|
||||||
|
|
||||||
def _forward_decode(
|
def _forward_decode(
|
||||||
self,
|
self,
|
||||||
@ -92,6 +89,12 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
# trtllm API requires extra dimension q_len_per_request for MTP
|
# trtllm API requires extra dimension q_len_per_request for MTP
|
||||||
q = q.unsqueeze(1)
|
q = q.unsqueeze(1)
|
||||||
|
|
||||||
|
if self.bmm1_scale is None:
|
||||||
|
self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
|
||||||
|
self.scale)
|
||||||
|
if self.bmm2_scale is None:
|
||||||
|
self.bmm2_scale = layer._v_scale_float
|
||||||
|
|
||||||
o = trtllm_batch_decode_with_kv_cache_mla(
|
o = trtllm_batch_decode_with_kv_cache_mla(
|
||||||
query=q,
|
query=q,
|
||||||
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
|
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
|
||||||
@ -102,7 +105,8 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
block_tables=attn_metadata.decode.block_table,
|
block_tables=attn_metadata.decode.block_table,
|
||||||
seq_lens=attn_metadata.decode.seq_lens,
|
seq_lens=attn_metadata.decode.seq_lens,
|
||||||
max_seq_len=attn_metadata.max_seq_len,
|
max_seq_len=attn_metadata.max_seq_len,
|
||||||
bmm1_scale=self.scale,
|
bmm1_scale=self.bmm1_scale,
|
||||||
|
bmm2_scale=self.bmm2_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Return LSE pending support from Flashinfer API:
|
# TODO: Return LSE pending support from Flashinfer API:
|
||||||
|
|||||||
@ -159,6 +159,9 @@ class EngineCore:
|
|||||||
self.request_block_hasher = get_request_block_hasher(
|
self.request_block_hasher = get_request_block_hasher(
|
||||||
block_size, caching_hash_fn)
|
block_size, caching_hash_fn)
|
||||||
|
|
||||||
|
self.step_fn = (self.step if self.batch_queue is None else
|
||||||
|
self.step_with_batch_queue)
|
||||||
|
|
||||||
def _initialize_kv_caches(
|
def _initialize_kv_caches(
|
||||||
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
|
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@ -331,7 +334,8 @@ class EngineCore:
|
|||||||
model_executed = False
|
model_executed = False
|
||||||
if self.scheduler.has_requests():
|
if self.scheduler.has_requests():
|
||||||
scheduler_output = self.scheduler.schedule()
|
scheduler_output = self.scheduler.schedule()
|
||||||
future = self.model_executor.execute_model(scheduler_output)
|
future = self.model_executor.execute_model(scheduler_output,
|
||||||
|
non_block=True)
|
||||||
batch_queue.appendleft(
|
batch_queue.appendleft(
|
||||||
(future, scheduler_output)) # type: ignore[arg-type]
|
(future, scheduler_output)) # type: ignore[arg-type]
|
||||||
|
|
||||||
@ -534,9 +538,6 @@ class EngineCoreProc(EngineCore):
|
|||||||
assert addresses.coordinator_input is not None
|
assert addresses.coordinator_input is not None
|
||||||
logger.info("Waiting for READY message from DP Coordinator...")
|
logger.info("Waiting for READY message from DP Coordinator...")
|
||||||
|
|
||||||
self.step_fn = (self.step if self.batch_queue is None else
|
|
||||||
self.step_with_batch_queue)
|
|
||||||
|
|
||||||
# Mark the startup heap as static so that it's ignored by GC.
|
# Mark the startup heap as static so that it's ignored by GC.
|
||||||
# Reduces pause times of oldest generation collections.
|
# Reduces pause times of oldest generation collections.
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user