mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-30 14:37:08 +08:00
Merge branch 'main' into Add_support_for_openpangu_promoe_v2
This commit is contained in:
commit
b4dce6858a
@ -73,12 +73,11 @@ function cpu_tests() {
|
||||
pytest -x -s -v \
|
||||
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs"
|
||||
|
||||
# Note: disable it until supports V1
|
||||
# Run AWQ test
|
||||
# docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
# set -e
|
||||
# pytest -x -s -v \
|
||||
# tests/quantization/test_ipex_quant.py"
|
||||
# Run AWQ/GPTQ test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -x -s -v \
|
||||
tests/quantization/test_cpu_wna16.py"
|
||||
|
||||
# Run multi-lora tests
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
|
||||
@ -1068,7 +1068,7 @@ steps:
|
||||
# this runner has 2 GPUs available even though num_gpus=2 is not set
|
||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||
# Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
|
||||
# Wrap with quotes to escape yaml
|
||||
# Wrap with quotes to escape yaml
|
||||
- "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'"
|
||||
|
||||
- label: Blackwell Fusion E2E Tests # 30 min
|
||||
@ -1095,10 +1095,11 @@ steps:
|
||||
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
|
||||
- pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
|
||||
|
||||
- label: Blackwell GPT-OSS Eval
|
||||
- label: ROCm GPT-OSS Eval
|
||||
timeout_in_minutes: 60
|
||||
working_dir: "/vllm-workspace/"
|
||||
gpu: b200
|
||||
agent_pool: mi325_1
|
||||
mirror_hardwares: [amdproduction]
|
||||
optional: true # run on nightlies
|
||||
source_file_dependencies:
|
||||
- tests/evals/gpt_oss
|
||||
@ -1107,7 +1108,7 @@ steps:
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
commands:
|
||||
- uv pip install --system 'gpt-oss[eval]==0.0.5'
|
||||
- pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
|
||||
- VLLM_ROCM_USE_AITER_MHA=0 VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
|
||||
|
||||
- label: Blackwell Quantized MoE Test
|
||||
timeout_in_minutes: 60
|
||||
|
||||
@ -478,10 +478,11 @@ steps:
|
||||
- vllm/
|
||||
- tests/compile
|
||||
commands:
|
||||
# fp8 kv scales not supported on sm89, tested on Blackwell instead
|
||||
- pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
|
||||
# Limit to no custom ops to reduce running time
|
||||
# Wrap with quotes to escape yaml and avoid starting -k string with a -
|
||||
- "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'"
|
||||
- "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and not +quant_fp8 and not Llama-4'"
|
||||
|
||||
- label: Cudagraph test
|
||||
timeout_in_minutes: 20
|
||||
@ -925,7 +926,7 @@ steps:
|
||||
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
|
||||
- pytest -v -s tests/kernels/moe/test_flashinfer.py
|
||||
|
||||
- label: Blackwell Fusion Tests # 30 min
|
||||
- label: Blackwell Fusion and Compile Tests # 30 min
|
||||
timeout_in_minutes: 40
|
||||
working_dir: "/vllm-workspace/"
|
||||
gpu: b200
|
||||
@ -946,7 +947,9 @@ steps:
|
||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||
# Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
|
||||
# Wrap with quotes to escape yaml
|
||||
- "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'"
|
||||
- "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
|
||||
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
|
||||
- pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
|
||||
|
||||
- label: Blackwell Fusion E2E Tests # 30 min
|
||||
timeout_in_minutes: 40
|
||||
@ -969,8 +972,6 @@ steps:
|
||||
- nvidia-smi
|
||||
# Run all e2e fusion tests
|
||||
- pytest -v -s tests/compile/test_fusions_e2e.py
|
||||
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
|
||||
- pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
|
||||
|
||||
- label: Blackwell GPT-OSS Eval
|
||||
timeout_in_minutes: 60
|
||||
@ -1266,7 +1267,8 @@ steps:
|
||||
- pytest -v -s tests/compile/test_async_tp.py
|
||||
- pytest -v -s tests/compile/test_sequence_parallelism.py
|
||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||
- pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
|
||||
- "pytest -v -s tests/compile/test_fusions_e2e.py -k 'not Llama-4'"
|
||||
- pytest -v -s tests/distributed/test_sequence_parallel.py
|
||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
|
||||
- pytest -v -s tests/v1/distributed/test_dbo.py
|
||||
|
||||
15
.github/workflows/macos-smoke-test.yml
vendored
15
.github/workflows/macos-smoke-test.yml
vendored
@ -1,6 +1,9 @@
|
||||
name: macOS Apple Silicon Smoke Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
workflow_dispatch: # Manual trigger
|
||||
|
||||
jobs:
|
||||
@ -19,13 +22,15 @@ jobs:
|
||||
pyproject.toml
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Create virtual environment
|
||||
run: |
|
||||
uv pip install -r requirements/cpu-build.txt
|
||||
uv pip install -r requirements/cpu.txt
|
||||
uv venv
|
||||
echo "$GITHUB_WORKSPACE/.venv/bin" >> "$GITHUB_PATH"
|
||||
|
||||
- name: Build vLLM
|
||||
run: uv pip install -v -e .
|
||||
- name: Install dependencies and build vLLM
|
||||
run: |
|
||||
uv pip install -r requirements/cpu.txt --index-strategy unsafe-best-match
|
||||
uv pip install -e .
|
||||
env:
|
||||
CMAKE_BUILD_PARALLEL_LEVEL: 4
|
||||
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -4,6 +4,9 @@
|
||||
# vllm-flash-attn built from source
|
||||
vllm/vllm_flash_attn/*
|
||||
|
||||
# OpenAI triton kernels copied from source
|
||||
vllm/third_party/triton_kernels/*
|
||||
|
||||
# triton jit
|
||||
.triton
|
||||
|
||||
|
||||
@ -512,9 +512,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
|
||||
# require CUDA 12.8 or later
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
@ -619,9 +619,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# FP4 Archs and flags
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
@ -695,7 +695,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
|
||||
@ -741,9 +741,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu")
|
||||
@ -861,7 +861,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
# Hadacore kernels
|
||||
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0;8.9;9.0" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
|
||||
if(HADACORE_ARCHS)
|
||||
set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
@ -1030,6 +1030,11 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
WITH_SOABI)
|
||||
endif()
|
||||
|
||||
# For CUDA and HIP builds also build the triton_kernels external package.
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
include(cmake/external_projects/triton_kernels.cmake)
|
||||
endif()
|
||||
|
||||
# For CUDA we also build and ship some external projects.
|
||||
if (VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
include(cmake/external_projects/flashmla.cmake)
|
||||
|
||||
380
benchmarks/benchmark_batch_invariance.py
Executable file
380
benchmarks/benchmark_batch_invariance.py
Executable file
@ -0,0 +1,380 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Benchmark to measure the performance overhead of VLLM_BATCH_INVARIANT mode.
|
||||
|
||||
This benchmark runs the same workload twice:
|
||||
1. With VLLM_BATCH_INVARIANT=0 (baseline)
|
||||
2. With VLLM_BATCH_INVARIANT=1 (batch invariant mode)
|
||||
|
||||
And reports the timing and throughput metrics for comparison.
|
||||
|
||||
Environment variables:
|
||||
VLLM_BENCH_MODEL: Model to benchmark (default: "Qwen/Qwen3-1.7B")
|
||||
VLLM_BENCH_TP_SIZE: Tensor parallel size (default: 1, use 8 for deepseek)
|
||||
VLLM_BENCH_BATCH_SIZE: Max batch size (default: 128)
|
||||
VLLM_BENCH_NUM_TRIALS: Number of trials to run (default: 5)
|
||||
VLLM_BENCH_MIN_PROMPT: Min prompt length in words (default: 1024)
|
||||
VLLM_BENCH_MAX_PROMPT: Max prompt length in words (default: 2048)
|
||||
VLLM_BENCH_MAX_TOKENS: Max tokens to generate (default: 128)
|
||||
VLLM_BENCH_TEMPERATURE: Temperature for sampling (default: 0.0)
|
||||
VLLM_BENCH_GPU_MEMORY_UTILIZATION: GPU memory utilization (default: 0.4)
|
||||
VLLM_BENCH_MAX_MODEL_LEN: Max model length (default: 5120)
|
||||
VLLM_BENCH_BACKEND: Attention backend (default: FLASH_ATTN)
|
||||
|
||||
Example usage:
|
||||
# Benchmark qwen3 (default)
|
||||
python benchmarks/benchmark_batch_invariance.py
|
||||
|
||||
# Benchmark deepseek with 8 GPUs
|
||||
VLLM_BENCH_MODEL="deepseek-ai/DeepSeek-V3" VLLM_BENCH_TP_SIZE=8 \\
|
||||
python benchmarks/benchmark_batch_invariance.py
|
||||
|
||||
# Quick test with fewer trials
|
||||
VLLM_BENCH_NUM_TRIALS=2 VLLM_BENCH_BATCH_SIZE=32 \\
|
||||
python benchmarks/benchmark_batch_invariance.py
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||
"""Generate a random prompt for benchmarking."""
|
||||
prompt_templates = [
|
||||
"Question: What is the capital of France?\nAnswer: The capital of France is",
|
||||
"Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
|
||||
"User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
|
||||
"Once upon a time in a distant galaxy, there lived",
|
||||
"The old man walked slowly down the street, remembering",
|
||||
"In the year 2157, humanity finally discovered",
|
||||
"To implement a binary search tree in Python, first we need to",
|
||||
"The algorithm works by iterating through the array and",
|
||||
"Here's how to optimize database queries using indexing:",
|
||||
"The Renaissance was a period in European history that",
|
||||
"Climate change is caused by several factors including",
|
||||
"The human brain contains approximately 86 billion neurons which",
|
||||
"I've been thinking about getting a new laptop because",
|
||||
"Yesterday I went to the store and bought",
|
||||
"My favorite thing about summer is definitely",
|
||||
]
|
||||
|
||||
base_prompt = random.choice(prompt_templates)
|
||||
|
||||
if max_words < min_words:
|
||||
max_words = min_words
|
||||
target_words = random.randint(min_words, max_words)
|
||||
|
||||
if target_words > 50:
|
||||
padding_text = (
|
||||
" This is an interesting topic that deserves more explanation. "
|
||||
* (target_words // 50)
|
||||
)
|
||||
base_prompt = base_prompt + padding_text
|
||||
|
||||
return base_prompt
|
||||
|
||||
|
||||
def run_benchmark_with_batch_invariant(
|
||||
model: str,
|
||||
tp_size: int,
|
||||
max_batch_size: int,
|
||||
num_trials: int,
|
||||
min_prompt: int,
|
||||
max_prompt: int,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
gpu_mem_util: float,
|
||||
max_model_len: int,
|
||||
backend: str,
|
||||
batch_invariant: bool,
|
||||
seed: int = 12345,
|
||||
) -> dict:
|
||||
"""
|
||||
Run the benchmark with the specified configuration.
|
||||
|
||||
Returns a dict with timing and throughput metrics.
|
||||
"""
|
||||
random.seed(seed)
|
||||
|
||||
# Set environment variables
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = backend
|
||||
if batch_invariant:
|
||||
os.environ["VLLM_BATCH_INVARIANT"] = "1"
|
||||
else:
|
||||
os.environ["VLLM_BATCH_INVARIANT"] = "0"
|
||||
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"BENCHMARK: VLLM_BATCH_INVARIANT={int(batch_invariant)}")
|
||||
print(f" Model: {model}")
|
||||
print(f" TP Size: {tp_size}")
|
||||
print(f" Backend: {backend}")
|
||||
print(f" Max Batch Size: {max_batch_size}")
|
||||
print(f" Trials: {num_trials}")
|
||||
print(f" Max Tokens: {max_tokens}")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
sampling = SamplingParams(
|
||||
temperature=temperature,
|
||||
top_p=0.95,
|
||||
max_tokens=max_tokens,
|
||||
seed=20240919,
|
||||
)
|
||||
|
||||
needle_prompt = "There once was a "
|
||||
|
||||
llm = None
|
||||
try:
|
||||
# Create LLM engine
|
||||
start_init = time.perf_counter()
|
||||
llm = LLM(
|
||||
model=model,
|
||||
max_num_seqs=max_batch_size,
|
||||
gpu_memory_utilization=gpu_mem_util,
|
||||
max_model_len=max_model_len,
|
||||
dtype="bfloat16",
|
||||
tensor_parallel_size=tp_size,
|
||||
enable_prefix_caching=False,
|
||||
)
|
||||
init_time = time.perf_counter() - start_init
|
||||
print(f"Engine initialization time: {init_time:.2f}s\n")
|
||||
|
||||
# Generate baseline
|
||||
print("Generating baseline (warmup)...")
|
||||
baseline_out = llm.generate([needle_prompt], sampling)
|
||||
assert len(baseline_out) == 1
|
||||
baseline_text = baseline_out[0].outputs[0].text
|
||||
print(f"Baseline output: '{baseline_text[:50]}...'\n")
|
||||
|
||||
# Run trials and measure timing
|
||||
trial_times: list[float] = []
|
||||
total_tokens = 0
|
||||
total_prompts = 0
|
||||
|
||||
for trial in range(num_trials):
|
||||
# Create a batch
|
||||
prompts: list[str] = []
|
||||
batch_size = random.randint(max_batch_size // 2, max_batch_size)
|
||||
needle_pos = random.randint(0, batch_size - 1)
|
||||
for i in range(batch_size):
|
||||
if i == needle_pos:
|
||||
prompts.append(needle_prompt)
|
||||
else:
|
||||
prompts.append(_random_prompt(min_prompt, max_prompt))
|
||||
|
||||
# Measure time for this trial
|
||||
start_time = time.perf_counter()
|
||||
outputs = llm.generate(prompts, sampling)
|
||||
trial_time = time.perf_counter() - start_time
|
||||
|
||||
trial_times.append(trial_time)
|
||||
total_prompts += len(prompts)
|
||||
|
||||
# Count tokens
|
||||
for output in outputs:
|
||||
if output.outputs:
|
||||
total_tokens += len(output.outputs[0].token_ids)
|
||||
|
||||
print(
|
||||
f"Trial {trial + 1}/{num_trials}: "
|
||||
f"batch_size={batch_size}, "
|
||||
f"time={trial_time:.2f}s"
|
||||
)
|
||||
|
||||
# Verify needle output still matches
|
||||
needle_output = outputs[needle_pos]
|
||||
assert needle_output.prompt == needle_prompt
|
||||
|
||||
# Compute statistics
|
||||
avg_time = sum(trial_times) / len(trial_times)
|
||||
min_time = min(trial_times)
|
||||
max_time = max(trial_times)
|
||||
throughput = total_tokens / sum(trial_times)
|
||||
prompts_per_sec = total_prompts / sum(trial_times)
|
||||
|
||||
print(f"\n{'=' * 80}")
|
||||
print("RESULTS:")
|
||||
print(f" Average time per trial: {avg_time:.2f}s")
|
||||
print(f" Min time: {min_time:.2f}s")
|
||||
print(f" Max time: {max_time:.2f}s")
|
||||
print(f" Total tokens generated: {total_tokens}")
|
||||
print(f" Total prompts processed: {total_prompts}")
|
||||
print(f" Throughput: {throughput:.2f} tokens/s")
|
||||
print(f" Prompts/s: {prompts_per_sec:.2f}")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
return {
|
||||
"init_time": init_time,
|
||||
"avg_time": avg_time,
|
||||
"min_time": min_time,
|
||||
"max_time": max_time,
|
||||
"total_tokens": total_tokens,
|
||||
"total_prompts": total_prompts,
|
||||
"throughput": throughput,
|
||||
"prompts_per_sec": prompts_per_sec,
|
||||
"trial_times": trial_times,
|
||||
}
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
if llm is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
llm.shutdown()
|
||||
|
||||
|
||||
def main():
|
||||
# Check platform support
|
||||
if not (current_platform.is_cuda() and current_platform.has_device_capability(90)):
|
||||
print("ERROR: Requires CUDA and >= Hopper (SM90)")
|
||||
print(f"Current platform: {current_platform.device_type}")
|
||||
if current_platform.is_cuda():
|
||||
print(f"Device capability: {current_platform.get_device_capability()}")
|
||||
return 1
|
||||
|
||||
# Read configuration from environment
|
||||
model = os.getenv("VLLM_BENCH_MODEL", "Qwen/Qwen3-1.7B")
|
||||
tp_size = int(os.getenv("VLLM_BENCH_TP_SIZE", "1"))
|
||||
max_batch_size = int(os.getenv("VLLM_BENCH_BATCH_SIZE", "128"))
|
||||
num_trials = int(os.getenv("VLLM_BENCH_NUM_TRIALS", "5"))
|
||||
min_prompt = int(os.getenv("VLLM_BENCH_MIN_PROMPT", "1024"))
|
||||
max_prompt = int(os.getenv("VLLM_BENCH_MAX_PROMPT", "2048"))
|
||||
max_tokens = int(os.getenv("VLLM_BENCH_MAX_TOKENS", "128"))
|
||||
temperature = float(os.getenv("VLLM_BENCH_TEMPERATURE", "0.0"))
|
||||
gpu_mem_util = float(os.getenv("VLLM_BENCH_GPU_MEMORY_UTILIZATION", "0.4"))
|
||||
max_model_len = int(os.getenv("VLLM_BENCH_MAX_MODEL_LEN", "5120"))
|
||||
backend = os.getenv("VLLM_BENCH_BACKEND", "FLASH_ATTN")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("VLLM BATCH INVARIANCE BENCHMARK")
|
||||
print("=" * 80)
|
||||
print("\nConfiguration:")
|
||||
print(f" Model: {model}")
|
||||
print(f" Tensor Parallel Size: {tp_size}")
|
||||
print(f" Attention Backend: {backend}")
|
||||
print(f" Max Batch Size: {max_batch_size}")
|
||||
print(f" Number of Trials: {num_trials}")
|
||||
print(f" Prompt Length Range: {min_prompt}-{max_prompt} words")
|
||||
print(f" Max Tokens to Generate: {max_tokens}")
|
||||
print(f" Temperature: {temperature}")
|
||||
print(f" GPU Memory Utilization: {gpu_mem_util}")
|
||||
print(f" Max Model Length: {max_model_len}")
|
||||
print("=" * 80)
|
||||
|
||||
# Run benchmark WITHOUT batch invariance (baseline)
|
||||
print("\n" + "=" * 80)
|
||||
print("PHASE 1: Running WITHOUT batch invariance (baseline)")
|
||||
print("=" * 80)
|
||||
baseline_results = run_benchmark_with_batch_invariant(
|
||||
model=model,
|
||||
tp_size=tp_size,
|
||||
max_batch_size=max_batch_size,
|
||||
num_trials=num_trials,
|
||||
min_prompt=min_prompt,
|
||||
max_prompt=max_prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
gpu_mem_util=gpu_mem_util,
|
||||
max_model_len=max_model_len,
|
||||
backend=backend,
|
||||
batch_invariant=False,
|
||||
)
|
||||
|
||||
# Run benchmark WITH batch invariance
|
||||
print("\n" + "=" * 80)
|
||||
print("PHASE 2: Running WITH batch invariance")
|
||||
print("=" * 80)
|
||||
batch_inv_results = run_benchmark_with_batch_invariant(
|
||||
model=model,
|
||||
tp_size=tp_size,
|
||||
max_batch_size=max_batch_size,
|
||||
num_trials=num_trials,
|
||||
min_prompt=min_prompt,
|
||||
max_prompt=max_prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
gpu_mem_util=gpu_mem_util,
|
||||
max_model_len=max_model_len,
|
||||
backend=backend,
|
||||
batch_invariant=True,
|
||||
)
|
||||
|
||||
# Compare results
|
||||
print("\n" + "=" * 80)
|
||||
print("COMPARISON: Batch Invariance vs Baseline")
|
||||
print("=" * 80)
|
||||
|
||||
init_overhead_pct = (
|
||||
(batch_inv_results["init_time"] - baseline_results["init_time"])
|
||||
/ baseline_results["init_time"]
|
||||
* 100
|
||||
)
|
||||
time_overhead_pct = (
|
||||
(batch_inv_results["avg_time"] - baseline_results["avg_time"])
|
||||
/ baseline_results["avg_time"]
|
||||
* 100
|
||||
)
|
||||
throughput_change_pct = (
|
||||
(batch_inv_results["throughput"] - baseline_results["throughput"])
|
||||
/ baseline_results["throughput"]
|
||||
* 100
|
||||
)
|
||||
|
||||
print("\nInitialization Time:")
|
||||
print(f" Baseline: {baseline_results['init_time']:.2f}s")
|
||||
print(f" Batch Invariant: {batch_inv_results['init_time']:.2f}s")
|
||||
print(f" Overhead: {init_overhead_pct:+.2f}%")
|
||||
|
||||
print("\nAverage Trial Time:")
|
||||
print(f" Baseline: {baseline_results['avg_time']:.2f}s")
|
||||
print(f" Batch Invariant: {batch_inv_results['avg_time']:.2f}s")
|
||||
print(f" Overhead: {time_overhead_pct:+.2f}%")
|
||||
|
||||
print("\nThroughput (tokens/s):")
|
||||
print(f" Baseline: {baseline_results['throughput']:.2f}")
|
||||
print(f" Batch Invariant: {batch_inv_results['throughput']:.2f}")
|
||||
print(f" Change: {throughput_change_pct:+.2f}%")
|
||||
|
||||
print("\nPrompts/s:")
|
||||
print(f" Baseline: {baseline_results['prompts_per_sec']:.2f}")
|
||||
print(f" Batch Invariant: {batch_inv_results['prompts_per_sec']:.2f}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("SUMMARY")
|
||||
print("=" * 80)
|
||||
if time_overhead_pct > 0:
|
||||
print(
|
||||
f"Batch invariance mode adds approximately {time_overhead_pct:.1f}% "
|
||||
"overhead"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Batch invariance mode is approximately {-time_overhead_pct:.1f}% "
|
||||
"faster (unexpected!)"
|
||||
)
|
||||
|
||||
if abs(throughput_change_pct) < 1.0:
|
||||
print("Throughput difference is negligible (< 1%)")
|
||||
elif throughput_change_pct < 0:
|
||||
print(
|
||||
f"Throughput decreased by {-throughput_change_pct:.1f}% "
|
||||
"with batch invariance"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Throughput increased by {throughput_change_pct:.1f}% "
|
||||
"with batch invariance (unexpected!)"
|
||||
)
|
||||
|
||||
print("=" * 80 + "\n")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
@ -255,8 +255,8 @@ def bench_run(
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Timing
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event = torch.Event(enable_timing=True)
|
||||
end_event = torch.Event(enable_timing=True)
|
||||
|
||||
latencies = []
|
||||
for _ in range(num_iters):
|
||||
|
||||
@ -185,8 +185,8 @@ def benchmark_config(
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event = torch.Event(enable_timing=True)
|
||||
end_event = torch.Event(enable_timing=True)
|
||||
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
|
||||
@ -105,8 +105,8 @@ def benchmark_permute(
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event = torch.Event(enable_timing=True)
|
||||
end_event = torch.Event(enable_timing=True)
|
||||
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
@ -241,8 +241,8 @@ def benchmark_unpermute(
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event = torch.Event(enable_timing=True)
|
||||
end_event = torch.Event(enable_timing=True)
|
||||
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
|
||||
@ -30,8 +30,8 @@ def _time_cuda(
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start = torch.Event(enable_timing=True)
|
||||
end = torch.Event(enable_timing=True)
|
||||
|
||||
start.record()
|
||||
for _ in range(bench_iters):
|
||||
|
||||
@ -253,8 +253,8 @@ def benchmark(
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event = torch.Event(enable_timing=True)
|
||||
end_event = torch.Event(enable_timing=True)
|
||||
|
||||
# Benchmark
|
||||
latencies: list[float] = []
|
||||
|
||||
@ -127,8 +127,8 @@ def benchmark_decode(
|
||||
|
||||
def time_fn(fn, warmup=10, trials=20):
|
||||
torch.cuda.synchronize()
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start = torch.Event(enable_timing=True)
|
||||
end = torch.Event(enable_timing=True)
|
||||
times = []
|
||||
for i in range(warmup):
|
||||
fn()
|
||||
|
||||
@ -139,8 +139,8 @@ def benchmark_prefill(
|
||||
|
||||
def time_fn(fn, warmup=10, trials=20):
|
||||
torch.cuda.synchronize()
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start = torch.Event(enable_timing=True)
|
||||
end = torch.Event(enable_timing=True)
|
||||
times = []
|
||||
for i in range(warmup):
|
||||
fn()
|
||||
|
||||
@ -183,8 +183,8 @@ def benchmark_config(
|
||||
run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event = torch.Event(enable_timing=True)
|
||||
end_event = torch.Event(enable_timing=True)
|
||||
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
|
||||
@ -55,6 +55,10 @@ output_num_chunks 166.0 99.01 11.80 79.00 90.00 98.00 108.75
|
||||
----------------------------------------------------------------------------------------------------
|
||||
```
|
||||
|
||||
If you run with `--warmup-step`, the summary will also include `warmup_runtime_sec`
|
||||
and `total_runtime_incl_warmup_sec` (while `runtime_sec` continues to reflect the
|
||||
benchmark-only runtime so the reported throughput stays comparable).
|
||||
|
||||
### JSON configuration file for synthetic conversations generation
|
||||
|
||||
The input flag `--input-file` is used to determine the input conversations for the benchmark.<br/>
|
||||
|
||||
@ -561,8 +561,11 @@ async def client_main(
|
||||
f"{Color.CYAN}Started client {client_id}: max_num_requests={args.max_num_requests}, max_active_conversations={args.max_active_conversations}{Color.RESET}" # noqa: E501
|
||||
)
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
# Set unique seed per client (each client runs in its own process)
|
||||
# Add 1 to ensure no client uses the same seed as the main process
|
||||
client_seed = args.seed + client_id + 1
|
||||
random.seed(client_seed)
|
||||
np.random.seed(client_seed)
|
||||
|
||||
# Active conversations
|
||||
active_convs: ConversationsMap = {}
|
||||
@ -1073,6 +1076,7 @@ def process_statistics(
|
||||
verbose: bool,
|
||||
gen_conv_args: GenConvArgs | None = None,
|
||||
excel_output: bool = False,
|
||||
warmup_runtime_sec: float | None = None,
|
||||
) -> None:
|
||||
if len(client_metrics) == 0:
|
||||
logger.info("No samples to process")
|
||||
@ -1166,8 +1170,13 @@ def process_statistics(
|
||||
# Convert milliseconds to seconds
|
||||
runtime_sec = runtime_sec / 1000.0
|
||||
requests_per_sec = float(len(df)) / runtime_sec
|
||||
|
||||
params = {"runtime_sec": runtime_sec, "requests_per_sec": requests_per_sec}
|
||||
params = {
|
||||
"runtime_sec": runtime_sec,
|
||||
"requests_per_sec": requests_per_sec,
|
||||
}
|
||||
if warmup_runtime_sec is not None:
|
||||
params["warmup_runtime_sec"] = warmup_runtime_sec
|
||||
params["total_runtime_incl_warmup_sec"] = runtime_sec + warmup_runtime_sec
|
||||
|
||||
# Generate a summary of relevant metrics (and drop irrelevant data)
|
||||
df = df.drop(columns=exclude).describe(percentiles=percentiles).transpose()
|
||||
@ -1490,6 +1499,7 @@ async def main() -> None:
|
||||
f"Invalid --warmup-percentage={args.warmup_percentage}"
|
||||
) from None
|
||||
|
||||
# Set global seeds for main process
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
@ -1548,6 +1558,8 @@ async def main() -> None:
|
||||
url=args.url, num_clients=args.num_clients, early_stop=not args.no_early_stop
|
||||
)
|
||||
|
||||
warmup_runtime_sec: float | None = None
|
||||
|
||||
# Warm-up step
|
||||
if args.warmup_step:
|
||||
# Only send a single user prompt from every conversation.
|
||||
@ -1562,26 +1574,56 @@ async def main() -> None:
|
||||
# all clients should finish their work before exiting
|
||||
warmup_bench_args = bench_args._replace(early_stop=False)
|
||||
|
||||
logger.info(f"{Color.PURPLE}Warmup start{Color.RESET}")
|
||||
logger.info("%sWarmup start%s", Color.PURPLE, Color.RESET)
|
||||
warmup_start_ns = time.perf_counter_ns()
|
||||
conversations, _ = await main_mp(
|
||||
warmup_client_args, req_args, warmup_bench_args, tokenizer, conversations
|
||||
)
|
||||
logger.info(f"{Color.PURPLE}Warmup done{Color.RESET}")
|
||||
warmup_runtime_sec = nanosec_to_sec(time.perf_counter_ns() - warmup_start_ns)
|
||||
logger.info(
|
||||
"%sWarmup runtime: %.3f sec (%.3f ms)%s",
|
||||
Color.PURPLE,
|
||||
warmup_runtime_sec,
|
||||
warmup_runtime_sec * 1000,
|
||||
Color.RESET,
|
||||
)
|
||||
logger.info("%sWarmup done%s", Color.PURPLE, Color.RESET)
|
||||
|
||||
# Run the benchmark
|
||||
start_time = time.perf_counter_ns()
|
||||
benchmark_start_ns = time.perf_counter_ns()
|
||||
client_convs, client_metrics = await main_mp(
|
||||
client_args, req_args, bench_args, tokenizer, conversations
|
||||
)
|
||||
total_runtime_ms = nanosec_to_millisec(time.perf_counter_ns() - start_time)
|
||||
benchmark_runtime_sec = nanosec_to_sec(time.perf_counter_ns() - benchmark_start_ns)
|
||||
|
||||
# Calculate requests per second
|
||||
total_runtime_sec = total_runtime_ms / 1000.0
|
||||
rps = len(client_metrics) / total_runtime_sec
|
||||
requests_per_sec = len(client_metrics) / benchmark_runtime_sec
|
||||
benchmark_runtime_ms = benchmark_runtime_sec * 1000.0
|
||||
logger.info(
|
||||
f"{Color.GREEN}All clients finished, total runtime: {total_runtime_sec:.3f} sec"
|
||||
f" ({total_runtime_ms:.3f} ms), requests per second: {rps:.3f}{Color.RESET}"
|
||||
"%sAll clients finished, benchmark runtime: %.3f sec (%.3f ms), "
|
||||
"requests per second: %.3f%s",
|
||||
Color.GREEN,
|
||||
benchmark_runtime_sec,
|
||||
benchmark_runtime_ms,
|
||||
requests_per_sec,
|
||||
Color.RESET,
|
||||
)
|
||||
if warmup_runtime_sec is not None:
|
||||
total_runtime_sec = benchmark_runtime_sec + warmup_runtime_sec
|
||||
logger.info(
|
||||
"%sWarmup runtime: %.3f sec (%.3f ms)%s",
|
||||
Color.GREEN,
|
||||
warmup_runtime_sec,
|
||||
warmup_runtime_sec * 1000,
|
||||
Color.RESET,
|
||||
)
|
||||
logger.info(
|
||||
"%sTotal runtime (including warmup): %.3f sec (%.3f ms)%s",
|
||||
Color.GREEN,
|
||||
total_runtime_sec,
|
||||
total_runtime_sec * 1000,
|
||||
Color.RESET,
|
||||
)
|
||||
|
||||
# Benchmark parameters
|
||||
params = {
|
||||
@ -1606,6 +1648,7 @@ async def main() -> None:
|
||||
verbose=args.verbose,
|
||||
gen_conv_args=gen_conv_args,
|
||||
excel_output=args.excel_output,
|
||||
warmup_runtime_sec=warmup_runtime_sec,
|
||||
)
|
||||
|
||||
if args.output_file is not None:
|
||||
|
||||
@ -375,6 +375,7 @@ set(VLLM_EXT_SRC
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/shm.cpp"
|
||||
"csrc/cpu/cpu_wna16.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
|
||||
set(VLLM_EXT_SRC
|
||||
|
||||
53
cmake/external_projects/triton_kernels.cmake
Normal file
53
cmake/external_projects/triton_kernels.cmake
Normal file
@ -0,0 +1,53 @@
|
||||
# Install OpenAI triton_kernels from https://github.com/triton-lang/triton/tree/main/python/triton_kernels
|
||||
|
||||
set(DEFAULT_TRITON_KERNELS_TAG "v3.5.0")
|
||||
|
||||
# Set TRITON_KERNELS_SRC_DIR for use with local development with vLLM. We expect TRITON_KERNELS_SRC_DIR to
|
||||
# be directly set to the triton_kernels python directory.
|
||||
if (DEFINED ENV{TRITON_KERNELS_SRC_DIR})
|
||||
message(STATUS "[triton_kernels] Fetch from $ENV{TRITON_KERNELS_SRC_DIR}")
|
||||
FetchContent_Declare(
|
||||
triton_kernels
|
||||
SOURCE_DIR $ENV{TRITON_KERNELS_SRC_DIR}
|
||||
)
|
||||
|
||||
else()
|
||||
set(TRITON_GIT "https://github.com/triton-lang/triton.git")
|
||||
message (STATUS "[triton_kernels] Fetch from ${TRITON_GIT}:${DEFAULT_TRITON_KERNELS_TAG}")
|
||||
FetchContent_Declare(
|
||||
triton_kernels
|
||||
# TODO (varun) : Fetch just the triton_kernels directory from Triton
|
||||
GIT_REPOSITORY https://github.com/triton-lang/triton.git
|
||||
GIT_TAG ${DEFAULT_TRITON_KERNELS_TAG}
|
||||
GIT_PROGRESS TRUE
|
||||
SOURCE_SUBDIR python/triton_kernels/triton_kernels
|
||||
)
|
||||
endif()
|
||||
|
||||
# Fetch content
|
||||
FetchContent_MakeAvailable(triton_kernels)
|
||||
|
||||
if (NOT triton_kernels_SOURCE_DIR)
|
||||
message (FATAL_ERROR "[triton_kernels] Cannot resolve triton_kernels_SOURCE_DIR")
|
||||
endif()
|
||||
|
||||
if (DEFINED ENV{TRITON_KERNELS_SRC_DIR})
|
||||
set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/")
|
||||
else()
|
||||
set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/python/triton_kernels/triton_kernels/")
|
||||
endif()
|
||||
|
||||
message (STATUS "[triton_kernels] triton_kernels is available at ${TRITON_KERNELS_PYTHON_DIR}")
|
||||
|
||||
add_custom_target(triton_kernels)
|
||||
|
||||
# Ensure the vllm/third_party directory exists before installation
|
||||
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/third_party/triton_kernels\")")
|
||||
|
||||
## Copy .py files to install directory.
|
||||
install(DIRECTORY
|
||||
${TRITON_KERNELS_PYTHON_DIR}
|
||||
DESTINATION
|
||||
vllm/third_party/triton_kernels/
|
||||
COMPONENT triton_kernels
|
||||
FILES_MATCHING PATTERN "*.py")
|
||||
@ -1,7 +1,6 @@
|
||||
#ifndef CPU_ATTN_HPP
|
||||
#define CPU_ATTN_HPP
|
||||
|
||||
#include <unistd.h>
|
||||
#include <type_traits>
|
||||
#include <cstddef>
|
||||
|
||||
@ -12,6 +11,7 @@
|
||||
#include "cpu_types.hpp"
|
||||
#include "scratchpad_manager.h"
|
||||
#include "cpu_attn_macros.h"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace cpu_attention {
|
||||
enum class ISA { AMX, VEC, VEC16 };
|
||||
@ -754,7 +754,7 @@ class AttentionScheduler {
|
||||
return l2_cache_size >> 1; // use 50% of L2 cache
|
||||
}
|
||||
// Fallback if sysctlbyname fails
|
||||
return 128 * 1024 >> 1; // use 50% of 128KB
|
||||
return 128LL * 1024 >> 1; // use 50% of 128KB
|
||||
#else
|
||||
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
|
||||
TORCH_CHECK_NE(l2_cache_size, -1);
|
||||
|
||||
@ -104,6 +104,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
explicit FP16Vec16(bool, void* ptr)
|
||||
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
|
||||
|
||||
explicit FP16Vec16(const c10::Half v) : reg(_mm256_set1_epi16(v.x)) {}
|
||||
|
||||
explicit FP16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
|
||||
@ -141,6 +143,8 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
explicit BF16Vec16(bool, void* ptr)
|
||||
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
|
||||
|
||||
explicit BF16Vec16(const c10::BFloat16 v) : reg(_mm256_set1_epi16(v.x)) {}
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
|
||||
@ -350,6 +354,22 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
explicit FP32Vec16(__m512 data) : reg(data) {}
|
||||
|
||||
// de-pack 4 bit values
|
||||
explicit FP32Vec16(int64_t value, const FP32Vec16& lut) {
|
||||
int64_t mask_0 = 0x0F0F0F0F0F0F0F0F;
|
||||
int64_t mask_1 = 0xF0F0F0F0F0F0F0F0;
|
||||
int64_t value_0 = value & mask_0;
|
||||
int64_t value_1 = value & mask_1;
|
||||
__m128i vec_0 = _mm_movpi64_epi64((__m64)value_0);
|
||||
__m128i vec_1 = _mm_movpi64_epi64((__m64)value_1);
|
||||
vec_0 = _mm_cvtepu8_epi16(vec_0);
|
||||
vec_1 = _mm_cvtepu8_epi16(vec_1);
|
||||
vec_1 = _mm_slli_epi16(vec_1, 4);
|
||||
__m128i vec = _mm_or_si128(vec_0, vec_1);
|
||||
__m512i vec_i32 = _mm512_cvtepu8_epi32(vec);
|
||||
reg = _mm512_permutexvar_ps(vec_i32, lut.reg);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4& data)
|
||||
: reg((__m512)_mm512_inserti32x4(
|
||||
_mm512_inserti32x4(
|
||||
@ -426,14 +446,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
float get_last_elem() const { return _mm512_cvtss_f32(reg); }
|
||||
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
|
||||
__mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
|
||||
return _mm512_mask_reduce_add_ps(mask, reg);
|
||||
}
|
||||
|
||||
void save(float* ptr) const { _mm512_storeu_ps(ptr, reg); }
|
||||
|
||||
void save(float* ptr, const int elem_num) const {
|
||||
@ -755,6 +767,25 @@ inline void non_temporal_save(BF16Vec16& vec, void* ptr) {
|
||||
inline void non_temporal_save(FP32Vec16& vec, void* ptr) {
|
||||
_mm512_stream_ps((float*)ptr, vec.reg);
|
||||
}
|
||||
|
||||
static void interleave_save(const BF16Vec16& vec0, const BF16Vec16& vec1,
|
||||
void* ptr) {
|
||||
__m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg);
|
||||
__m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg);
|
||||
vec_1 = _mm512_slli_epi32(vec_1, 16);
|
||||
vec_0 = _mm512_or_si512(vec_0, vec_1);
|
||||
_mm512_storeu_epi32(ptr, vec_0);
|
||||
}
|
||||
|
||||
static void interleave_save(const FP16Vec16& vec0, const FP16Vec16& vec1,
|
||||
void* ptr) {
|
||||
__m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg);
|
||||
__m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg);
|
||||
vec_1 = _mm512_slli_epi32(vec_1, 16);
|
||||
vec_0 = _mm512_or_si512(vec_0, vec_1);
|
||||
_mm512_storeu_epi32(ptr, vec_0);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
inline void mem_barrier() { _mm_mfence(); }
|
||||
|
||||
402
csrc/cpu/cpu_wna16.cpp
Normal file
402
csrc/cpu/cpu_wna16.cpp
Normal file
@ -0,0 +1,402 @@
|
||||
#include "cpu_types.hpp"
|
||||
#include "scratchpad_manager.h"
|
||||
#include "utils.hpp"
|
||||
|
||||
#ifdef CPU_CAPABILITY_AMXBF16
|
||||
#include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
|
||||
#endif
|
||||
#include "cpu/micro_gemm/cpu_micro_gemm_vec.hpp"
|
||||
|
||||
#define VLLM_DISPATCH_CASE_16B_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_16B_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_16B_TYPES(__VA_ARGS__))
|
||||
|
||||
template <typename T>
|
||||
void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
|
||||
int32_t stride) {
|
||||
std::stringstream ss;
|
||||
ss << std::fixed << std::setprecision(5) << name << ": [\n";
|
||||
auto* curr_logits_buffer = ptr;
|
||||
for (int32_t m = 0; m < row; ++m) {
|
||||
for (int32_t n = 0; n < col; ++n) {
|
||||
ss << curr_logits_buffer[n] << ", ";
|
||||
}
|
||||
ss << "\n";
|
||||
curr_logits_buffer += stride;
|
||||
}
|
||||
ss << "]\n";
|
||||
std::printf("%s", ss.str().c_str());
|
||||
}
|
||||
|
||||
namespace {
|
||||
using cpu_utils::ISA;
|
||||
using cpu_utils::VecTypeTrait;
|
||||
|
||||
template <typename scalar_t, ISA isa, bool has_zp, bool use_desc_act>
|
||||
class Dequantizer4b {
|
||||
public:
|
||||
constexpr static int32_t pack_num = 32 / 4;
|
||||
using scalar_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
|
||||
|
||||
public:
|
||||
static void dequant(int32_t* __restrict__ q_weight,
|
||||
scalar_t* __restrict__ weight,
|
||||
scalar_t* __restrict__ scales,
|
||||
int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx,
|
||||
const int64_t scales_stride, const int64_t zeros_stride,
|
||||
const int32_t k_size, const int32_t group_size) {
|
||||
vec_op::FP32Vec16 lut;
|
||||
if constexpr (has_zp) {
|
||||
// AWQ
|
||||
alignas(64) static const float LUT[16] = {
|
||||
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
|
||||
8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f};
|
||||
lut = vec_op::FP32Vec16(LUT);
|
||||
} else {
|
||||
// GPTQ
|
||||
alignas(64) static const float LUT[16] = {
|
||||
-8.0f, -7.0f, -6.0f, -5.0f, -4.0f, -3.0f, -2.0f, -1.0f,
|
||||
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
|
||||
lut = vec_op::FP32Vec16(LUT);
|
||||
}
|
||||
|
||||
// per 64-bits elem contains 16 output channels
|
||||
int64_t* __restrict__ curr_q_weight = reinterpret_cast<int64_t*>(q_weight);
|
||||
int64_t* __restrict__ curr_zeros = reinterpret_cast<int64_t*>(zeros);
|
||||
scalar_t* __restrict__ curr_weight = weight;
|
||||
scalar_t* __restrict__ curr_scale = scales;
|
||||
vec_op::FP32Vec16 scale_0;
|
||||
vec_op::FP32Vec16 scale_1;
|
||||
vec_op::FP32Vec16 zero_0;
|
||||
vec_op::FP32Vec16 zero_1;
|
||||
int32_t group_counter = 0;
|
||||
for (int32_t k_idx = 0; k_idx < k_size; k_idx += 2) {
|
||||
int64_t qwb_0 = *curr_q_weight;
|
||||
int64_t qwb_1 = *(curr_q_weight + 1);
|
||||
vec_op::FP32Vec16 wb_0(qwb_0, lut);
|
||||
vec_op::FP32Vec16 wb_1(qwb_1, lut);
|
||||
|
||||
if constexpr (!use_desc_act) {
|
||||
if (group_counter == 0) {
|
||||
scale_0 = vec_op::FP32Vec16(scalar_vec_t(curr_scale));
|
||||
scale_1 = vec_op::FP32Vec16(scale_0);
|
||||
curr_scale += scales_stride;
|
||||
|
||||
if constexpr (has_zp) {
|
||||
zero_0 = vec_op::FP32Vec16(*curr_zeros, lut);
|
||||
zero_1 = vec_op::FP32Vec16(zero_0);
|
||||
curr_zeros += zeros_stride / 2;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int32_t g_idx_0 = g_idx[k_idx];
|
||||
int32_t g_idx_1 = g_idx[k_idx + 1];
|
||||
scale_0 = vec_op::FP32Vec16(
|
||||
scalar_vec_t(curr_scale + g_idx_0 * scales_stride));
|
||||
scale_1 = vec_op::FP32Vec16(
|
||||
scalar_vec_t(curr_scale + g_idx_1 * scales_stride));
|
||||
if constexpr (has_zp) {
|
||||
zero_0 = vec_op::FP32Vec16(*(curr_zeros + g_idx_0 * zeros_stride / 2),
|
||||
lut);
|
||||
zero_1 = vec_op::FP32Vec16(*(curr_zeros + g_idx_1 * zeros_stride / 2),
|
||||
lut);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (has_zp) {
|
||||
wb_0 = wb_0 - zero_0;
|
||||
wb_1 = wb_1 - zero_1;
|
||||
}
|
||||
|
||||
wb_0 = wb_0 * scale_0;
|
||||
wb_1 = wb_1 * scale_1;
|
||||
|
||||
scalar_vec_t output_vec_0(wb_0);
|
||||
scalar_vec_t output_vec_1(wb_1);
|
||||
|
||||
// AMX needs to interlave K elements to pack as 32 bits
|
||||
if constexpr (isa == ISA::AMX) {
|
||||
vec_op::interleave_save(output_vec_0, output_vec_1, curr_weight);
|
||||
} else {
|
||||
output_vec_0.save(curr_weight);
|
||||
output_vec_1.save(curr_weight + 16);
|
||||
}
|
||||
|
||||
// update
|
||||
curr_q_weight += 2;
|
||||
curr_weight += 32;
|
||||
if constexpr (!use_desc_act) {
|
||||
group_counter += 2;
|
||||
if (group_counter == group_size) {
|
||||
group_counter = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}; // namespace
|
||||
|
||||
template <typename scalar_t, typename dequantizer_t, typename gemm_t>
|
||||
void cpu_gemm_wna16_impl(
|
||||
scalar_t* __restrict__ input, int32_t* __restrict__ q_weight,
|
||||
scalar_t* __restrict__ output, scalar_t* __restrict__ scales,
|
||||
int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx,
|
||||
scalar_t* __restrict__ bias, const int32_t m_size, const int32_t n_size,
|
||||
const int32_t k_size, const int64_t input_stride,
|
||||
const int64_t output_stride, const int64_t scales_group_stride,
|
||||
const int64_t zeros_group_stride, const int32_t group_num,
|
||||
const int32_t group_size, const int64_t pack_factor) {
|
||||
constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
|
||||
constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
|
||||
constexpr int32_t n_block_size = 16;
|
||||
static_assert(gemm_n_tile_size % n_block_size == 0);
|
||||
const int32_t thread_num = omp_get_max_threads();
|
||||
|
||||
// a simple schedule policy, just to hold more B tiles in L2 and make sure
|
||||
// each thread has tasks
|
||||
const int32_t n_partition_size = [&]() {
|
||||
const int64_t cache_size = cpu_utils::get_l2_size();
|
||||
int64_t ps_cache_limit = cache_size / (k_size * sizeof(scalar_t));
|
||||
int64_t ps_thread_limit = n_size / thread_num;
|
||||
ps_cache_limit =
|
||||
std::max((ps_cache_limit / gemm_n_tile_size) * gemm_n_tile_size,
|
||||
(int64_t)gemm_n_tile_size);
|
||||
ps_thread_limit =
|
||||
std::max((ps_thread_limit / gemm_n_tile_size) * gemm_n_tile_size,
|
||||
(int64_t)gemm_n_tile_size);
|
||||
return std::min(ps_cache_limit, ps_thread_limit);
|
||||
}();
|
||||
const int32_t task_num = (n_size + n_partition_size - 1) / n_partition_size;
|
||||
|
||||
// get buffer size
|
||||
const int64_t b_buffer_size =
|
||||
(((n_partition_size * k_size * sizeof(scalar_t) + 63) / 64) * 64);
|
||||
const int64_t c_buffer_size =
|
||||
(((gemm_m_tile_size * gemm_n_tile_size * sizeof(float) + 63) / 64) * 64);
|
||||
const int64_t b_buffer_offset = 0;
|
||||
const int64_t c_buffer_offset = b_buffer_size;
|
||||
const int64_t buffer_size = b_buffer_size + c_buffer_size;
|
||||
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(buffer_size *
|
||||
thread_num);
|
||||
|
||||
alignas(64) cpu_utils::Counter counter;
|
||||
cpu_utils::Counter* counter_ptr = &counter;
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
|
||||
scalar_t* __restrict__ b_buffer = nullptr;
|
||||
float* __restrict__ c_buffer = nullptr;
|
||||
{
|
||||
uint8_t* buffer_ptr = DNNLScratchPadManager::get_dnnl_scratchpad_manager()
|
||||
->get_data<uint8_t>() +
|
||||
thread_id * buffer_size;
|
||||
b_buffer = reinterpret_cast<scalar_t*>(buffer_ptr + b_buffer_offset);
|
||||
c_buffer = reinterpret_cast<float*>(buffer_ptr + c_buffer_offset);
|
||||
}
|
||||
|
||||
const int64_t q_weight_block_stride = n_block_size / pack_factor * k_size;
|
||||
const int64_t b_buffer_block_stride = n_block_size * k_size;
|
||||
const int32_t zeros_block_stride = n_block_size / pack_factor;
|
||||
|
||||
gemm_t gemm;
|
||||
|
||||
for (;;) {
|
||||
int32_t task_id = counter_ptr->acquire_counter();
|
||||
|
||||
if (task_id >= task_num) {
|
||||
break;
|
||||
}
|
||||
|
||||
const int32_t n_start_idx = task_id * n_partition_size;
|
||||
const int32_t n_block_start_idx = n_start_idx / n_block_size;
|
||||
const int32_t n_num = std::min(n_partition_size, n_size - n_start_idx);
|
||||
const int32_t n_block_num = n_num / n_block_size;
|
||||
// std::printf("thread_id: %d, task_id: %d, n_start_idx: %d, n_num: %d\n",
|
||||
// thread_id, task_id, n_start_idx, n_num);
|
||||
|
||||
// dequant weight
|
||||
{
|
||||
int32_t* __restrict__ curr_q_weight =
|
||||
q_weight + n_block_start_idx * q_weight_block_stride;
|
||||
scalar_t* __restrict__ curr_b_buffer = b_buffer;
|
||||
scalar_t* __restrict__ curr_scales = scales + n_start_idx;
|
||||
int32_t* __restrict__ curr_zeros = zeros + n_start_idx / pack_factor;
|
||||
for (int32_t block_idx = 0; block_idx < n_block_num; ++block_idx) {
|
||||
dequantizer_t::dequant(curr_q_weight, curr_b_buffer, curr_scales,
|
||||
curr_zeros, g_idx, scales_group_stride,
|
||||
zeros_group_stride, k_size, group_size);
|
||||
|
||||
// if (block_idx == 0 && n_start_idx == 0) {
|
||||
// print_logits("depacked weight", curr_b_buffer, k_size,
|
||||
// n_block_size, n_block_size);
|
||||
// }
|
||||
|
||||
// update
|
||||
curr_q_weight += q_weight_block_stride;
|
||||
curr_b_buffer += b_buffer_block_stride;
|
||||
curr_scales += n_block_size;
|
||||
curr_zeros += zeros_block_stride;
|
||||
}
|
||||
}
|
||||
|
||||
// compute loop
|
||||
{
|
||||
const int32_t n_tile_num = n_num / gemm_n_tile_size;
|
||||
scalar_t* __restrict__ curr_input = input;
|
||||
scalar_t* __restrict__ init_bias = bias;
|
||||
if (bias != nullptr) {
|
||||
init_bias += n_start_idx;
|
||||
}
|
||||
scalar_t* __restrict__ init_output = output + n_start_idx;
|
||||
for (int32_t m_idx = 0; m_idx < m_size; m_idx += gemm_m_tile_size) {
|
||||
const int32_t curr_m_size =
|
||||
std::min(gemm_m_tile_size, m_size - m_idx);
|
||||
scalar_t* __restrict__ curr_b_buffer = b_buffer;
|
||||
scalar_t* __restrict__ curr_bias = init_bias;
|
||||
scalar_t* __restrict__ curr_output = init_output;
|
||||
for (int32_t n_tile_idx = 0; n_tile_idx < n_tile_num; ++n_tile_idx) {
|
||||
gemm.gemm(curr_input, curr_b_buffer, c_buffer, curr_m_size, k_size,
|
||||
input_stride, b_buffer_block_stride, gemm_n_tile_size,
|
||||
false);
|
||||
|
||||
if (bias != nullptr) {
|
||||
cpu_micro_gemm::bias_epilogue<gemm_n_tile_size>(
|
||||
c_buffer, curr_output, curr_bias, curr_m_size,
|
||||
gemm_n_tile_size, output_stride);
|
||||
curr_bias += gemm_n_tile_size;
|
||||
} else {
|
||||
cpu_micro_gemm::default_epilogue<gemm_n_tile_size>(
|
||||
c_buffer, curr_output, curr_m_size, gemm_n_tile_size,
|
||||
output_stride);
|
||||
}
|
||||
|
||||
curr_b_buffer +=
|
||||
b_buffer_block_stride * (gemm_n_tile_size / n_block_size);
|
||||
curr_output += gemm_n_tile_size;
|
||||
}
|
||||
curr_input += gemm_m_tile_size * input_stride;
|
||||
init_output += gemm_m_tile_size * output_stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void cpu_gemm_wna16(
|
||||
const torch::Tensor& input, // [M, K]
|
||||
const torch::Tensor&
|
||||
q_weight, // [N / 16, K * 16 / pack_factor], packed as int32
|
||||
torch::Tensor& output, // [M, N]
|
||||
const torch::Tensor& scales, // [group_num, N]
|
||||
const std::optional<torch::Tensor>&
|
||||
zeros, // [group_num, N / pack_factor], packed as int32
|
||||
const std::optional<torch::Tensor>& g_idx, // [K]
|
||||
const std::optional<torch::Tensor>& bias, // [N]
|
||||
const int64_t pack_factor, const std::string& isa_hint) {
|
||||
using cpu_utils::ISA;
|
||||
TORCH_CHECK_EQ(pack_factor, 8); // only supports 4bits
|
||||
const int32_t a_m_size = input.size(0);
|
||||
const int32_t a_k_size = input.size(1);
|
||||
const int64_t a_m_stride = input.stride(0);
|
||||
const int32_t b_n_size = q_weight.size(0) * 16;
|
||||
TORCH_CHECK_EQ(a_k_size % 32, 0);
|
||||
TORCH_CHECK_EQ(b_n_size % 32, 0);
|
||||
const int32_t group_num = scales.size(0);
|
||||
const int32_t group_size = a_k_size / group_num;
|
||||
TORCH_CHECK_EQ(group_size % 2, 0);
|
||||
const int64_t scales_group_stride = scales.stride(0);
|
||||
const int64_t output_m_stride = output.stride(0);
|
||||
|
||||
bool has_zp = zeros.has_value();
|
||||
bool use_desc_act = g_idx.has_value();
|
||||
TORCH_CHECK(!(has_zp && use_desc_act));
|
||||
|
||||
ISA isa = [&]() {
|
||||
if (isa_hint == "amx") {
|
||||
return ISA::AMX;
|
||||
} else if (isa_hint == "vec") {
|
||||
return ISA::VEC;
|
||||
} else {
|
||||
TORCH_CHECK(false, "unsupported isa hint: " + isa_hint);
|
||||
}
|
||||
}();
|
||||
|
||||
int32_t* zeros_ptr = has_zp ? zeros->data_ptr<int32_t>() : nullptr;
|
||||
const int64_t zeros_group_stride = has_zp ? zeros->stride(0) : 0;
|
||||
int32_t* g_idx_ptr = use_desc_act ? g_idx->data_ptr<int32_t>() : nullptr;
|
||||
|
||||
VLLM_DISPATCH_16B_TYPES(input.scalar_type(), "cpu_gemm_wna16", [&]() {
|
||||
if (isa == ISA::AMX) {
|
||||
using gemm_t = cpu_micro_gemm::MicroGemm<ISA::AMX, scalar_t>;
|
||||
if (has_zp) {
|
||||
using dequantizer_t = Dequantizer4b<scalar_t, ISA::AMX, true, false>;
|
||||
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
|
||||
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
|
||||
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
|
||||
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
|
||||
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
|
||||
scales_group_stride, zeros_group_stride, group_num, group_size,
|
||||
pack_factor);
|
||||
return;
|
||||
}
|
||||
if (use_desc_act) {
|
||||
using dequantizer_t = Dequantizer4b<scalar_t, ISA::AMX, false, true>;
|
||||
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
|
||||
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
|
||||
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
|
||||
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
|
||||
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
|
||||
scales_group_stride, zeros_group_stride, group_num, group_size,
|
||||
pack_factor);
|
||||
return;
|
||||
} else {
|
||||
using dequantizer_t = Dequantizer4b<scalar_t, ISA::AMX, false, false>;
|
||||
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
|
||||
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
|
||||
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
|
||||
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
|
||||
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
|
||||
scales_group_stride, zeros_group_stride, group_num, group_size,
|
||||
pack_factor);
|
||||
return;
|
||||
}
|
||||
} else if (isa == ISA::VEC) {
|
||||
using gemm_t = cpu_micro_gemm::MicroGemm<ISA::VEC, scalar_t>;
|
||||
if (has_zp) {
|
||||
using dequantizer_t = Dequantizer4b<scalar_t, ISA::VEC, true, false>;
|
||||
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
|
||||
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
|
||||
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
|
||||
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
|
||||
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
|
||||
scales_group_stride, zeros_group_stride, group_num, group_size,
|
||||
pack_factor);
|
||||
return;
|
||||
}
|
||||
if (use_desc_act) {
|
||||
using dequantizer_t = Dequantizer4b<scalar_t, ISA::VEC, false, true>;
|
||||
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
|
||||
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
|
||||
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
|
||||
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
|
||||
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
|
||||
scales_group_stride, zeros_group_stride, group_num, group_size,
|
||||
pack_factor);
|
||||
return;
|
||||
} else {
|
||||
using dequantizer_t = Dequantizer4b<scalar_t, ISA::VEC, false, false>;
|
||||
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
|
||||
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
|
||||
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
|
||||
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
|
||||
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
|
||||
scales_group_stride, zeros_group_stride, group_num, group_size,
|
||||
pack_factor);
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
@ -396,9 +396,9 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
|
||||
: DNNLMatMulPrimitiveHandler(
|
||||
static_cast<DNNLMatMulPrimitiveHandler::Args>(args), args.ab_type),
|
||||
m_size_cache_(nullptr) {
|
||||
assert(ab_type_ == dnnl::memory::data_type::f32 ||
|
||||
ab_type_ == dnnl::memory::data_type::bf16 ||
|
||||
ab_type_ == dnnl::memory::data_type::f16);
|
||||
assert(b_type_ == dnnl::memory::data_type::f32 ||
|
||||
b_type_ == dnnl::memory::data_type::bf16 ||
|
||||
b_type_ == dnnl::memory::data_type::f16);
|
||||
|
||||
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
|
||||
{b_k_stride_, b_n_stride_});
|
||||
|
||||
245
csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp
Normal file
245
csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp
Normal file
@ -0,0 +1,245 @@
|
||||
#ifndef CPU_MICRO_GEMM_AMX_HPP
|
||||
#define CPU_MICRO_GEMM_AMX_HPP
|
||||
#include "cpu/micro_gemm/cpu_micro_gemm_impl.hpp"
|
||||
|
||||
namespace cpu_micro_gemm {
|
||||
namespace {
|
||||
// AMX specific
|
||||
constexpr static int64_t AMX_TILE_ROW_BYTES = 64;
|
||||
constexpr static int64_t AMX_TILE_ROW_NUM = 16;
|
||||
constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM;
|
||||
|
||||
typedef struct __tile_config {
|
||||
uint8_t palette_id = 1;
|
||||
uint8_t start_row = 0;
|
||||
uint8_t reserved_0[14] = {0};
|
||||
uint16_t colsb[16] = {0};
|
||||
uint8_t rows[16] = {0};
|
||||
} __tilecfg;
|
||||
|
||||
// 2-2-4 pattern, for 16 < m <= 32
|
||||
// TILE 0, 1: load A matrix, row num should be 16, m - 16
|
||||
// TILE 2, 3: load B matrix, row num should be 16
|
||||
// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m
|
||||
// - 16
|
||||
template <typename scalar_t>
|
||||
class TileGemm224 {
|
||||
public:
|
||||
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
TORCH_CHECK(false, "Unsupported data type for TileGemm224");
|
||||
}
|
||||
|
||||
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
|
||||
TORCH_CHECK(false, "Unsupported data type for TileGemm224");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class TileGemm224<c10::BFloat16> {
|
||||
public:
|
||||
using scalar_t = c10::BFloat16;
|
||||
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
const int32_t k_times = k / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
|
||||
c10::BFloat16* __restrict__ a_tile_0 = a_ptr;
|
||||
c10::BFloat16* __restrict__ a_tile_1 = a_ptr + lda * AMX_TILE_ROW_NUM;
|
||||
const int64_t a_tile_stride = lda * sizeof(c10::BFloat16);
|
||||
|
||||
// B is always packed as 16 output channels block
|
||||
c10::BFloat16* __restrict__ b_tile_2 = b_ptr;
|
||||
c10::BFloat16* __restrict__ b_tile_3 = b_ptr + b_n_group_stride;
|
||||
const int32_t b_tile_stride = AMX_TILE_ROW_BYTES;
|
||||
|
||||
float* __restrict__ c_tile_4 = c_ptr;
|
||||
float* __restrict__ c_tile_5 =
|
||||
c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float);
|
||||
float* __restrict__ c_tile_6 = c_ptr + AMX_TILE_ROW_NUM * ldc;
|
||||
float* __restrict__ c_tile_7 =
|
||||
c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float);
|
||||
const int32_t c_tile_stride = ldc * sizeof(float);
|
||||
|
||||
if (accum_c) {
|
||||
_tile_loadd(4, c_tile_4, c_tile_stride);
|
||||
_tile_loadd(5, c_tile_5, c_tile_stride);
|
||||
_tile_loadd(6, c_tile_6, c_tile_stride);
|
||||
_tile_loadd(7, c_tile_7, c_tile_stride);
|
||||
} else {
|
||||
_tile_zero(4);
|
||||
_tile_zero(5);
|
||||
_tile_zero(6);
|
||||
_tile_zero(7);
|
||||
}
|
||||
|
||||
for (int32_t k = 0; k < k_times; ++k) {
|
||||
_tile_loadd(0, a_tile_0, a_tile_stride);
|
||||
_tile_stream_loadd(2, b_tile_2, b_tile_stride);
|
||||
_tile_dpbf16ps(4, 0, 2);
|
||||
_tile_stream_loadd(3, b_tile_3, b_tile_stride);
|
||||
_tile_dpbf16ps(5, 0, 3);
|
||||
_tile_loadd(1, a_tile_1, a_tile_stride);
|
||||
_tile_dpbf16ps(6, 1, 2);
|
||||
_tile_dpbf16ps(7, 1, 3);
|
||||
|
||||
// update ptrs
|
||||
a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
|
||||
a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
}
|
||||
|
||||
_tile_stored(4, c_tile_4, c_tile_stride);
|
||||
_tile_stored(5, c_tile_5, c_tile_stride);
|
||||
_tile_stored(6, c_tile_6, c_tile_stride);
|
||||
_tile_stored(7, c_tile_7, c_tile_stride);
|
||||
}
|
||||
|
||||
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
|
||||
const int32_t m_0 = AMX_TILE_ROW_NUM;
|
||||
const int32_t m_1 = m - AMX_TILE_ROW_NUM;
|
||||
config.rows[0] = m_0;
|
||||
config.rows[1] = m_1;
|
||||
config.rows[2] = AMX_TILE_ROW_NUM;
|
||||
config.rows[3] = AMX_TILE_ROW_NUM;
|
||||
config.rows[4] = m_0;
|
||||
config.rows[5] = m_0;
|
||||
config.rows[6] = m_1;
|
||||
config.rows[7] = m_1;
|
||||
_tile_loadconfig(&config);
|
||||
}
|
||||
};
|
||||
|
||||
// 1-2-2 pattern, for 0 < m <= 16
|
||||
// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be
|
||||
// m, m
|
||||
// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row
|
||||
// num should be 16
|
||||
// TILE 6, 7, (6, 7): store results C matrix, row num should be
|
||||
// m
|
||||
template <typename scalar_t>
|
||||
class TileGemm122 {
|
||||
public:
|
||||
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
TORCH_CHECK(false, "Unsupported data type for TileGemm122");
|
||||
}
|
||||
|
||||
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
|
||||
TORCH_CHECK(false, "Unsupported data type for TileGemm122");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class TileGemm122<c10::BFloat16> {
|
||||
public:
|
||||
using scalar_t = c10::BFloat16;
|
||||
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
c10::BFloat16* __restrict__ a_tile_0 = a_ptr;
|
||||
c10::BFloat16* __restrict__ a_tile_1 =
|
||||
a_ptr + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
|
||||
const int64_t a_tile_stride = lda * sizeof(c10::BFloat16);
|
||||
|
||||
c10::BFloat16* __restrict__ b_tile_2 = b_ptr;
|
||||
c10::BFloat16* __restrict__ b_tile_3 = b_ptr + b_n_group_stride;
|
||||
c10::BFloat16* __restrict__ b_tile_4 =
|
||||
b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
c10::BFloat16* __restrict__ b_tile_5 =
|
||||
b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
int64_t b_stride = AMX_TILE_ROW_BYTES;
|
||||
|
||||
float* __restrict__ c_tile_6 = c_ptr;
|
||||
float* __restrict__ c_tile_7 = c_ptr + AMX_TILE_ROW_BYTES / sizeof(float);
|
||||
int64_t c_stride = ldc * sizeof(float);
|
||||
|
||||
const int32_t k_times = k / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
|
||||
const int32_t k_group_times = k_times / 2;
|
||||
const bool has_tail = (k_times % 2 == 1);
|
||||
|
||||
if (accum_c) {
|
||||
_tile_loadd(6, c_tile_6, c_stride);
|
||||
_tile_loadd(7, c_tile_7, c_stride);
|
||||
} else {
|
||||
_tile_zero(6);
|
||||
_tile_zero(7);
|
||||
}
|
||||
|
||||
for (int32_t k = 0; k < k_group_times; ++k) {
|
||||
_tile_loadd(0, a_tile_0, a_tile_stride);
|
||||
_tile_stream_loadd(2, b_tile_2, b_stride);
|
||||
_tile_dpbf16ps(6, 0, 2);
|
||||
_tile_stream_loadd(3, b_tile_3, b_stride);
|
||||
_tile_dpbf16ps(7, 0, 3);
|
||||
_tile_loadd(1, a_tile_1, a_tile_stride);
|
||||
_tile_stream_loadd(4, b_tile_4, b_stride);
|
||||
_tile_dpbf16ps(6, 1, 4);
|
||||
_tile_stream_loadd(5, b_tile_5, b_stride);
|
||||
_tile_dpbf16ps(7, 1, 5);
|
||||
|
||||
// update ptrs
|
||||
a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
|
||||
a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
}
|
||||
|
||||
if (has_tail) {
|
||||
_tile_loadd(0, a_tile_0, a_tile_stride);
|
||||
_tile_stream_loadd(2, b_tile_2, b_stride);
|
||||
_tile_dpbf16ps(6, 0, 2);
|
||||
_tile_stream_loadd(3, b_tile_3, b_stride);
|
||||
_tile_dpbf16ps(7, 0, 3);
|
||||
}
|
||||
|
||||
_tile_stored(6, c_tile_6, c_stride);
|
||||
_tile_stored(7, c_tile_7, c_stride);
|
||||
}
|
||||
|
||||
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
|
||||
config.rows[0] = m;
|
||||
config.rows[1] = m;
|
||||
config.rows[2] = AMX_TILE_ROW_NUM;
|
||||
config.rows[3] = AMX_TILE_ROW_NUM;
|
||||
config.rows[4] = AMX_TILE_ROW_NUM;
|
||||
config.rows[5] = AMX_TILE_ROW_NUM;
|
||||
config.rows[6] = m;
|
||||
config.rows[7] = m;
|
||||
_tile_loadconfig(&config);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Gemm kernel uses AMX, requires B matrix to be packed
|
||||
template <typename scalar_t>
|
||||
class MicroGemm<cpu_utils::ISA::AMX, scalar_t> {
|
||||
public:
|
||||
static constexpr int32_t MaxMSize = 32;
|
||||
static constexpr int32_t NSize = 32;
|
||||
|
||||
public:
|
||||
MicroGemm() : curr_m_(-1) {
|
||||
vec_op::unroll_loop<int, 8>([&](int i) { amx_tile_config_.colsb[i] = 64; });
|
||||
}
|
||||
|
||||
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
if (m > AMX_TILE_ROW_NUM) {
|
||||
if (m != curr_m_) {
|
||||
curr_m_ = m;
|
||||
TileGemm224<scalar_t>::init_tile_config(m, amx_tile_config_);
|
||||
}
|
||||
TileGemm224<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
|
||||
} else {
|
||||
if (m != curr_m_) {
|
||||
curr_m_ = m;
|
||||
TileGemm122<scalar_t>::init_tile_config(m, amx_tile_config_);
|
||||
}
|
||||
TileGemm122<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
alignas(64) __tilecfg amx_tile_config_;
|
||||
int32_t curr_m_;
|
||||
};
|
||||
|
||||
} // namespace cpu_micro_gemm
|
||||
|
||||
#endif
|
||||
91
csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp
Normal file
91
csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp
Normal file
@ -0,0 +1,91 @@
|
||||
#ifndef CPU_MICRO_GEMM_IMPL_HPP
|
||||
#define CPU_MICRO_GEMM_IMPL_HPP
|
||||
#include "cpu/utils.hpp"
|
||||
#include "cpu/cpu_types.hpp"
|
||||
|
||||
namespace cpu_micro_gemm {
|
||||
#define DEFINE_CPU_MICRO_GEMM_PARAMS \
|
||||
scalar_t *__restrict__ a_ptr, scalar_t *__restrict__ b_ptr, \
|
||||
float *__restrict__ c_ptr, const int32_t m, const int32_t k, \
|
||||
const int64_t lda, const int64_t b_n_group_stride, const int64_t ldc, \
|
||||
const bool accum_c
|
||||
|
||||
#define CPU_MICRO_GEMM_PARAMS \
|
||||
a_ptr, b_ptr, c_ptr, m, k, lda, b_n_group_stride, ldc, accum_c
|
||||
|
||||
template <cpu_utils::ISA isa, typename scalar_t>
|
||||
class MicroGemm {
|
||||
public:
|
||||
static constexpr int32_t MaxMSize = 16;
|
||||
static constexpr int32_t NSize = 16;
|
||||
|
||||
public:
|
||||
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
TORCH_CHECK(false, "Unimplemented MicroGemm.");
|
||||
}
|
||||
};
|
||||
|
||||
template <int32_t n_size, typename scalar_t>
|
||||
FORCE_INLINE void default_epilogue(float* __restrict__ c_ptr,
|
||||
scalar_t* __restrict__ d_ptr,
|
||||
const int32_t m, const int64_t ldc,
|
||||
const int64_t ldd) {
|
||||
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
|
||||
static_assert(n_size % 16 == 0);
|
||||
|
||||
float* __restrict__ curr_c = c_ptr;
|
||||
scalar_t* __restrict__ curr_d = d_ptr;
|
||||
for (int32_t i = 0; i < m; ++i) {
|
||||
float* __restrict__ curr_c_iter = curr_c;
|
||||
scalar_t* __restrict__ curr_d_iter = curr_d;
|
||||
vec_op::unroll_loop<int32_t, n_size / 16>([&](int32_t n_g_idx) {
|
||||
vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
|
||||
scalar_vec_t c_vec(c_vec_fp32);
|
||||
c_vec.save(curr_d_iter);
|
||||
curr_c_iter += 16;
|
||||
curr_d_iter += 16;
|
||||
});
|
||||
curr_c += ldc;
|
||||
curr_d += ldd;
|
||||
}
|
||||
}
|
||||
|
||||
template <int32_t n_size, typename scalar_t>
|
||||
FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr,
|
||||
scalar_t* __restrict__ d_ptr,
|
||||
scalar_t* __restrict__ bias_ptr,
|
||||
const int32_t m, const int64_t ldc,
|
||||
const int64_t ldd) {
|
||||
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
|
||||
static_assert(n_size % 16 == 0);
|
||||
constexpr int32_t n_group_num = n_size / 16;
|
||||
static_assert(n_group_num <= 16);
|
||||
|
||||
vec_op::FP32Vec16 bias_vecs[n_group_num];
|
||||
scalar_t* __restrict__ curr_bias = bias_ptr;
|
||||
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t i) {
|
||||
scalar_vec_t vec(curr_bias);
|
||||
bias_vecs[i] = vec_op::FP32Vec16(vec);
|
||||
curr_bias += 16;
|
||||
});
|
||||
|
||||
float* __restrict__ curr_c = c_ptr;
|
||||
scalar_t* __restrict__ curr_d = d_ptr;
|
||||
for (int32_t i = 0; i < m; ++i) {
|
||||
float* __restrict__ curr_c_iter = curr_c;
|
||||
scalar_t* __restrict__ curr_d_iter = curr_d;
|
||||
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t n_g_idx) {
|
||||
vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
|
||||
c_vec_fp32 = c_vec_fp32 + bias_vecs[n_g_idx];
|
||||
scalar_vec_t c_vec(c_vec_fp32);
|
||||
c_vec.save(curr_d_iter);
|
||||
curr_c_iter += 16;
|
||||
curr_d_iter += 16;
|
||||
});
|
||||
curr_c += ldc;
|
||||
curr_d += ldd;
|
||||
}
|
||||
}
|
||||
} // namespace cpu_micro_gemm
|
||||
|
||||
#endif
|
||||
115
csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp
Normal file
115
csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp
Normal file
@ -0,0 +1,115 @@
|
||||
#ifndef CPU_MICRO_GEMM_VEC_HPP
|
||||
#define CPU_MICRO_GEMM_VEC_HPP
|
||||
#include "cpu/micro_gemm/cpu_micro_gemm_impl.hpp"
|
||||
|
||||
namespace cpu_micro_gemm {
|
||||
namespace {
|
||||
// 8-2-16 pattern, 8 regs for A, 2 regs for B, 16 regs for C, [8, K] @ [k, 32]
|
||||
template <typename scalar_t>
|
||||
class TileGemm82 {
|
||||
public:
|
||||
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
switch (m) {
|
||||
case 1:
|
||||
gemm_micro<1>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
case 2:
|
||||
gemm_micro<2>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
case 3:
|
||||
gemm_micro<3>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
case 4:
|
||||
gemm_micro<4>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
case 5:
|
||||
gemm_micro<5>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
case 6:
|
||||
gemm_micro<6>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
case 7:
|
||||
gemm_micro<7>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
case 8:
|
||||
gemm_micro<8>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <int32_t M>
|
||||
static void gemm_micro(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
static_assert(0 < M <= 8);
|
||||
using load_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
|
||||
|
||||
scalar_t* __restrict__ curr_b_0 = b_ptr;
|
||||
scalar_t* __restrict__ curr_b_1 = b_ptr + b_n_group_stride;
|
||||
float* __restrict__ curr_c_0 = c_ptr;
|
||||
float* __restrict__ curr_c_1 = c_ptr + 16;
|
||||
|
||||
vec_op::FP32Vec16 c_regs[M * 2];
|
||||
if (accum_c) {
|
||||
float* __restrict__ curr_m_c_0 = curr_c_0;
|
||||
float* __restrict__ curr_m_c_1 = curr_c_1;
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
c_regs[i * 2] = vec_op::FP32Vec16(curr_m_c_0);
|
||||
c_regs[i * 2 + 1] = vec_op::FP32Vec16(curr_m_c_1);
|
||||
|
||||
// update
|
||||
curr_m_c_0 += ldc;
|
||||
curr_m_c_1 += ldc;
|
||||
});
|
||||
}
|
||||
|
||||
scalar_t* __restrict__ curr_a = a_ptr;
|
||||
for (int32_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||
load_vec_t b_0_reg(curr_b_0);
|
||||
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
|
||||
load_vec_t b_1_reg(curr_b_1);
|
||||
vec_op::FP32Vec16 fp32_b_1_reg(b_1_reg);
|
||||
|
||||
scalar_t* __restrict__ curr_m_a = curr_a;
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
scalar_t v = *curr_m_a;
|
||||
load_vec_t a_reg_original(v);
|
||||
vec_op::FP32Vec16 a_reg(a_reg_original);
|
||||
c_regs[i * 2] = c_regs[i * 2] + a_reg * fp32_b_0_reg;
|
||||
c_regs[i * 2 + 1] = c_regs[i * 2 + 1] + a_reg * fp32_b_1_reg;
|
||||
|
||||
// update
|
||||
curr_m_a += lda;
|
||||
});
|
||||
|
||||
// update
|
||||
curr_a += 1;
|
||||
curr_b_0 += 16;
|
||||
curr_b_1 += 16;
|
||||
}
|
||||
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
c_regs[i * 2].save(curr_c_0);
|
||||
c_regs[i * 2 + 1].save(curr_c_1);
|
||||
|
||||
// update
|
||||
curr_c_0 += ldc;
|
||||
curr_c_1 += ldc;
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Gemm kernel uses vector instructions, requires B matrix to be packed
|
||||
template <typename scalar_t>
|
||||
class MicroGemm<cpu_utils::ISA::VEC, scalar_t> {
|
||||
public:
|
||||
static constexpr int32_t MaxMSize = 8;
|
||||
static constexpr int32_t NSize = 32;
|
||||
|
||||
public:
|
||||
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
TileGemm82<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
|
||||
}
|
||||
};
|
||||
} // namespace cpu_micro_gemm
|
||||
|
||||
#endif
|
||||
@ -100,6 +100,16 @@ void cpu_attention_with_kv_cache(
|
||||
const torch::Tensor& scheduler_metadata,
|
||||
const std::optional<torch::Tensor>& s_aux);
|
||||
|
||||
// Note: just for avoiding importing errors
|
||||
void placeholder_op() { TORCH_CHECK(false, "Unimplemented"); }
|
||||
|
||||
void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
|
||||
torch::Tensor& output, const torch::Tensor& scales,
|
||||
const std::optional<torch::Tensor>& zeros,
|
||||
const std::optional<torch::Tensor>& g_idx,
|
||||
const std::optional<torch::Tensor>& bias,
|
||||
const int64_t pack_factor, const std::string& isa_hint);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@ -275,6 +285,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"sliding_window_left, SymInt sliding_window_right, Tensor block_table, "
|
||||
"float softcap, Tensor sheduler_metadata, Tensor? s_aux) -> ()",
|
||||
&cpu_attention_with_kv_cache);
|
||||
|
||||
// placeholders
|
||||
ops.def("static_scaled_fp8_quant() -> ()", placeholder_op);
|
||||
ops.def("dynamic_scaled_fp8_quant() -> ()", placeholder_op);
|
||||
ops.def("dynamic_per_token_scaled_fp8_quant() -> ()", placeholder_op);
|
||||
|
||||
// WNA16
|
||||
#if defined(__AVX512F__)
|
||||
ops.def(
|
||||
"cpu_gemm_wna16(Tensor input, Tensor q_weight, Tensor(a2!) output, "
|
||||
"Tensor scales, Tensor? zeros, Tensor? g_idx, Tensor? bias, SymInt "
|
||||
"pack_factor, str isa_hint) -> ()");
|
||||
ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||
|
||||
55
csrc/cpu/utils.hpp
Normal file
55
csrc/cpu/utils.hpp
Normal file
@ -0,0 +1,55 @@
|
||||
#ifndef UTILS_HPP
|
||||
#define UTILS_HPP
|
||||
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
namespace cpu_utils {
|
||||
enum class ISA { AMX, VEC };
|
||||
|
||||
template <typename T>
|
||||
struct VecTypeTrait {
|
||||
using vec_t = void;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct VecTypeTrait<float> {
|
||||
using vec_t = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct VecTypeTrait<c10::BFloat16> {
|
||||
using vec_t = vec_op::BF16Vec16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct VecTypeTrait<c10::Half> {
|
||||
using vec_t = vec_op::FP16Vec16;
|
||||
};
|
||||
|
||||
struct Counter {
|
||||
std::atomic<int64_t> counter;
|
||||
char _padding[56];
|
||||
|
||||
Counter() : counter(0) {}
|
||||
|
||||
void reset_counter() { counter.store(0); }
|
||||
|
||||
int64_t acquire_counter() { return counter++; }
|
||||
};
|
||||
|
||||
inline int64_t get_l2_size() {
|
||||
static int64_t size = []() {
|
||||
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
|
||||
assert(l2_cache_size != -1);
|
||||
return l2_cache_size >> 1; // use 50% of L2 cache
|
||||
}();
|
||||
return size;
|
||||
}
|
||||
} // namespace cpu_utils
|
||||
|
||||
#endif
|
||||
@ -802,7 +802,7 @@ torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace) {
|
||||
});
|
||||
|
||||
if (numel % 256 != 0) {
|
||||
out = out.index({torch::indexing::Slice(0, numel / had_size)});
|
||||
out = out.narrow(0, 0, numel / had_size);
|
||||
}
|
||||
|
||||
if (inplace && out.data_ptr() != x.data_ptr()) {
|
||||
|
||||
@ -14,6 +14,7 @@ RUN apt clean && apt-get update -y && \
|
||||
libxext6 \
|
||||
libgl1 \
|
||||
lsb-release \
|
||||
libaio-dev \
|
||||
numactl \
|
||||
wget \
|
||||
vim \
|
||||
@ -68,8 +69,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
|
||||
# install nixl from source code
|
||||
ENV NIXL_VERSION=0.7.0
|
||||
RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py
|
||||
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages/.nixl.mesonpy.libs/plugins/"
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip uninstall oneccl oneccl-devel -y
|
||||
|
||||
@ -30,8 +30,8 @@ Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at
|
||||
Where to get started with vLLM depends on the type of user. If you are looking to:
|
||||
|
||||
- Run open-source models on vLLM, we recommend starting with the [Quickstart Guide](./getting_started/quickstart.md)
|
||||
- Build applications with vLLM, we recommend starting with the [User Guide](./usage)
|
||||
- Build vLLM, we recommend starting with [Developer Guide](./contributing)
|
||||
- Build applications with vLLM, we recommend starting with the [User Guide](./usage/README.md)
|
||||
- Build vLLM, we recommend starting with [Developer Guide](./contributing/README.md)
|
||||
|
||||
For information about the development of vLLM, see:
|
||||
|
||||
|
||||
@ -4,6 +4,6 @@
|
||||
|
||||
--8<-- "docs/cli/json_tip.inc.md"
|
||||
|
||||
## Options
|
||||
## Arguments
|
||||
|
||||
--8<-- "docs/argparse/bench_latency.md"
|
||||
--8<-- "docs/argparse/bench_latency.inc.md"
|
||||
|
||||
@ -4,6 +4,6 @@
|
||||
|
||||
--8<-- "docs/cli/json_tip.inc.md"
|
||||
|
||||
## Options
|
||||
## Arguments
|
||||
|
||||
--8<-- "docs/argparse/bench_serve.md"
|
||||
--8<-- "docs/argparse/bench_serve.inc.md"
|
||||
|
||||
@ -4,6 +4,6 @@
|
||||
|
||||
--8<-- "docs/cli/json_tip.inc.md"
|
||||
|
||||
## Options
|
||||
## Arguments
|
||||
|
||||
--8<-- "docs/argparse/bench_sweep_plot.md"
|
||||
--8<-- "docs/argparse/bench_sweep_plot.inc.md"
|
||||
|
||||
@ -4,6 +4,6 @@
|
||||
|
||||
--8<-- "docs/cli/json_tip.inc.md"
|
||||
|
||||
## Options
|
||||
## Arguments
|
||||
|
||||
--8<-- "docs/argparse/bench_sweep_serve.md"
|
||||
--8<-- "docs/argparse/bench_sweep_serve.inc.md"
|
||||
|
||||
@ -4,6 +4,6 @@
|
||||
|
||||
--8<-- "docs/cli/json_tip.inc.md"
|
||||
|
||||
## Options
|
||||
## Arguments
|
||||
|
||||
--8<-- "docs/argparse/bench_sweep_serve_sla.md"
|
||||
--8<-- "docs/argparse/bench_sweep_serve_sla.inc.md"
|
||||
|
||||
@ -4,6 +4,6 @@
|
||||
|
||||
--8<-- "docs/cli/json_tip.inc.md"
|
||||
|
||||
## Options
|
||||
## Arguments
|
||||
|
||||
--8<-- "docs/argparse/bench_throughput.md"
|
||||
--8<-- "docs/argparse/bench_throughput.inc.md"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# vllm chat
|
||||
|
||||
## Options
|
||||
## Arguments
|
||||
|
||||
--8<-- "docs/argparse/chat.md"
|
||||
--8<-- "docs/argparse/chat.inc.md"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# vllm complete
|
||||
|
||||
## Options
|
||||
## Arguments
|
||||
|
||||
--8<-- "docs/argparse/complete.md"
|
||||
--8<-- "docs/argparse/complete.inc.md"
|
||||
|
||||
@ -4,6 +4,6 @@
|
||||
|
||||
--8<-- "docs/cli/json_tip.inc.md"
|
||||
|
||||
## Options
|
||||
## Arguments
|
||||
|
||||
--8<-- "docs/argparse/run-batch.md"
|
||||
--8<-- "docs/argparse/run-batch.inc.md"
|
||||
|
||||
@ -4,6 +4,6 @@
|
||||
|
||||
--8<-- "docs/cli/json_tip.inc.md"
|
||||
|
||||
## Options
|
||||
## Arguments
|
||||
|
||||
--8<-- "docs/argparse/serve.md"
|
||||
--8<-- "docs/argparse/serve.inc.md"
|
||||
|
||||
@ -5,7 +5,7 @@ The `vllm serve` command is used to launch the OpenAI-compatible server.
|
||||
## CLI Arguments
|
||||
|
||||
The `vllm serve` command is used to launch the OpenAI-compatible server.
|
||||
To see the available options, take a look at the [CLI Reference](../cli/README.md#options)!
|
||||
To see the available options, take a look at the [CLI Reference](../cli/README.md)!
|
||||
|
||||
## Configuration file
|
||||
|
||||
|
||||
@ -983,7 +983,7 @@ each document has close to 512 tokens.
|
||||
|
||||
Please note that the `/v1/rerank` is also supported by embedding models. So if you're running
|
||||
with an embedding model, also set `--no_reranker`. Because in this case the query is
|
||||
treated as a individual prompt by the server, here we send `random_batch_size - 1` documents
|
||||
treated as an individual prompt by the server, here we send `random_batch_size - 1` documents
|
||||
to account for the extra prompt which is the query. The token accounting to report the
|
||||
throughput numbers correctly is also adjusted.
|
||||
|
||||
|
||||
@ -224,6 +224,6 @@ snakeviz expensive_function.prof
|
||||
|
||||
Leverage VLLM_GC_DEBUG environment variable to debug GC costs.
|
||||
|
||||
- VLLM_GC_DEBUG=1: enable GC debugger with gc.collect elpased times
|
||||
- VLLM_GC_DEBUG=1: enable GC debugger with gc.collect elapsed times
|
||||
- VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger to log top 5
|
||||
collected objects for each gc.collect
|
||||
|
||||
@ -128,7 +128,7 @@ A [CUDAGraphWrapper][vllm.compilation.cuda_graph.CUDAGraphWrapper] instance wrap
|
||||
3. Otherwise, i.e., the runtime_mode matches the mode of the wrapper, the wrapper will perform CUDA Graphs capture (if key does not exist, create
|
||||
a new entry and cache it) or replay (if key exists in the cache).
|
||||
|
||||
The above steps are based on the assumption that the CUDA Graphs wrapper would directly trust what’s in the forward context (controlled by the dispatcher). This lets us simplify and cenralize the logic, reducing the complexity as well as the risk of mismatched state between the wrappers and the dispatcher. It also allows reusing the wrapper class for both `FULL` and `PIECEWISE` runtime modes. See the implementation [here](https://github.com/vllm-project/vllm/blob/f751e50b7a2aae3110d83ed0d88202fc91b3e78a/vllm/compilation/cuda_graph.py#L106).
|
||||
The above steps are based on the assumption that the CUDA Graphs wrapper would directly trust what’s in the forward context (controlled by the dispatcher). This lets us simplify and centralize the logic, reducing the complexity as well as the risk of mismatched state between the wrappers and the dispatcher. It also allows reusing the wrapper class for both `FULL` and `PIECEWISE` runtime modes. See the implementation [here](https://github.com/vllm-project/vllm/blob/f751e50b7a2aae3110d83ed0d88202fc91b3e78a/vllm/compilation/cuda_graph.py#L106).
|
||||
|
||||
#### Nested Wrapper design
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# IO Processor Plugins
|
||||
|
||||
IO Processor plugins are a feature that allows pre and post processing of the model input and output for pooling models. The idea is that users are allowed to pass a custom input to vLLM that is converted into one or more model prompts and fed to the model `encode` method. One potential use-case of such plugins is that of using vLLM for generating multi-modal data. Say users feed an image to vLLM and get an image in output.
|
||||
IO Processor plugins are a feature that allows pre- and post-processing of the model input and output for pooling models. The idea is that users are allowed to pass a custom input to vLLM that is converted into one or more model prompts and fed to the model `encode` method. One potential use-case of such plugins is that of using vLLM for generating multi-modal data. Say users feed an image to vLLM and get an image in output.
|
||||
|
||||
When performing an inference with IO Processor plugins, the prompt type is defined by the plugin and the same is valid for the final request output. vLLM does not perform any validation of input/output data, and it is up to the plugin to ensure the correct data is being fed to the model and returned to the user. As of now these plugins support only pooling models and can be triggered via the `encode` method in `LLM` and `AsyncLLM`, or in online serving mode via the `/pooling` endpoint.
|
||||
|
||||
|
||||
@ -411,7 +411,7 @@ Logits processor `update_state()` implementations should assume the following mo
|
||||
|
||||
* **"Condense" the batch to be contiguous:** starting with the lowest-index empty slot (which was caused by a Remove), apply a Unidirectional Move from the current highest non-empty slot in the batch to fill the empty slot. Proceed with additional Unidirectional Move operations in order of increasing empty slot destination index and decreasing non-empty slot source index until the batch is contiguous
|
||||
|
||||
* **Shrink the batch:** a side-effect of condensing the batch is that empty slots resulting from Remove operations are grouped in a contiguous block at the end of the batch array. Thus, after condensing, update `BatchUpdate.batch_size` to reflect the number of non-empty slots
|
||||
* **Shrink the batch:** a side effect of condensing the batch is that empty slots resulting from Remove operations are grouped in a contiguous block at the end of the batch array. Thus, after condensing, update `BatchUpdate.batch_size` to reflect the number of non-empty slots
|
||||
|
||||
5. Reorder the batch for improved efficiency. Depending on the attention backend implementation and the current characteristics of the batch, zero or more Swap Move operations may be applied to reorder the batch
|
||||
|
||||
@ -548,7 +548,7 @@ Built-in logits processors are always loaded when the vLLM engine starts. See th
|
||||
|
||||
Review these logits processor implementations for guidance on writing built-in logits processors.
|
||||
|
||||
Additionally, the following logits-processor-like functionalities are hard-coded into the sampler and do not yet utilize the programming model described above. Most of them will be refactored to use the aforemented logits processor programming model.
|
||||
Additionally, the following logits-processor-like functionalities are hard-coded into the sampler and do not yet utilize the programming model described above. Most of them will be refactored to use the aforementioned logits processor programming model.
|
||||
|
||||
* Allowed token IDs
|
||||
|
||||
|
||||
@ -68,7 +68,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
|
||||
|
||||
## Fused MoE Experts Kernels
|
||||
|
||||
The are a number of MoE experts kernel implementations for different quantization types and architectures. Most follow the general API of the base Triton [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts] function. Many have modular kernel adatpers so they can be used with compatible all2all backends. This table lists each experts kernel and its particular properties.
|
||||
The are a number of MoE experts kernel implementations for different quantization types and architectures. Most follow the general API of the base Triton [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts] function. Many have modular kernel adapters so they can be used with compatible all2all backends. This table lists each experts kernel and its particular properties.
|
||||
|
||||
Each kernel must be provided with one of the supported input activation formats. Some flavors of kernels support both standard and batched formats through different entry points, e.g. `TritonExperts` and `BatchedTritonExperts`. Batched format kernels are currently only needed for matching with certain all2all backends, e.g. `pplx`, `DeepEPLLPrepareAndFinalize`.
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ You can use vLLM *custom arguments* to pass in arguments which are not part of t
|
||||
Custom arguments can be useful if, for example, you want to use a [custom logits processor](./custom_logitsprocs.md) without modifying the vLLM source code.
|
||||
|
||||
!!! note
|
||||
Make sure your custom logits processor have implemented `validate_params` for custom arguments. Otherwise invalid custom arguments can cause unexpected behaviour.
|
||||
Make sure your custom logits processor have implemented `validate_params` for custom arguments. Otherwise, invalid custom arguments can cause unexpected behaviour.
|
||||
|
||||
## Offline Custom Arguments
|
||||
|
||||
|
||||
@ -71,7 +71,7 @@ Logits processor `update_state()` implementations should assume the following mo
|
||||
|
||||
* **"Condense" the batch to be contiguous:** starting with the lowest-index empty slot (which was caused by a Remove), apply a Unidirectional Move from the current highest non-empty slot in the batch to fill the empty slot. Proceed with additional Unidirectional Move operations in order of increasing empty slot destination index and decreasing non-empty slot source index until the batch is contiguous
|
||||
|
||||
* **Shrink the batch:** a side-effect of condensing the batch is that empty slots resulting from Remove operations are grouped in a contiguous block at the end of the batch array. Thus, after condensing, update `BatchUpdate.batch_size` to reflect the number of non-empty slots
|
||||
* **Shrink the batch:** a side effect of condensing the batch is that empty slots resulting from Remove operations are grouped in a contiguous block at the end of the batch array. Thus, after condensing, update `BatchUpdate.batch_size` to reflect the number of non-empty slots
|
||||
|
||||
5. Reorder the batch for improved efficiency. Depending on the attention backend implementation and the current characteristics of the batch, zero or more Swap Move operations may be applied to reorder the batch
|
||||
|
||||
@ -286,7 +286,7 @@ Once you have created a custom subclass (like `WrappedPerReqLogitsProcessor`) wh
|
||||
|
||||
## Ways to Load Your Custom Logits Processor in vLLM
|
||||
|
||||
Logits processors are loaded at initialization. Critically, the set of loaded logits processors cannot be modified after the vLLM engine finishes loading, and new logits logits processors cannot be loaded on-demand for individual requests.
|
||||
Logits processors are loaded at initialization. Critically, the set of loaded logits processors cannot be modified after the vLLM engine finishes loading, and new logits processors cannot be loaded on-demand for individual requests.
|
||||
|
||||
This section details different ways of making your logits processor visible to vLLM and triggering vLLM to load your logits processor.
|
||||
|
||||
@ -438,7 +438,7 @@ The examples below show how a user would pass a custom argument (`target_token`)
|
||||
|
||||
## Best Practices for Writing Custom Logits Processors
|
||||
|
||||
Once vLLM loads a logits processor during initialization, then vLLM will invoke `update_state()` and `apply()` against that logits processor in every engine step. Both methods operate on all requests which currently reside in the vLLM persistent batch. Thus it is important to implement these methods efficiently.
|
||||
Once vLLM loads a logits processor during initialization, then vLLM will invoke `update_state()` and `apply()` against that logits processor in every engine step. Both methods operate on all requests which currently reside in the vLLM persistent batch. Thus, it is important to implement these methods efficiently.
|
||||
|
||||
* Write efficient `apply()` and `update_state()` implementations in light of the fact that logits processors operate at batch granularity
|
||||
* For example, you may be able to use efficient vectorized operations to implement `apply()` or update internal state vectors in `update_state()`
|
||||
@ -465,4 +465,4 @@ Once vLLM loads a logits processor during initialization, then vLLM will invoke
|
||||
|
||||
* **Note:** for wrapped per-request logits processors, the `AdapterLogitsProcessor` base-class handles this by default
|
||||
|
||||
* `is_argmax_invariant()` can be hard-coded to `True` or `False` if the logits processor has consistent behavior. However the argmax invariance may also be determined programmatically (i.e. if your logits processor is user-customizable in some way that impacts whether the logits processor is argmax invariant). For this reason, `is_argmax_invariant()` is not a class method
|
||||
* `is_argmax_invariant()` can be hard-coded to `True` or `False` if the logits processor has consistent behavior. However, the argmax invariance may also be determined programmatically (i.e. if your logits processor is user-customizable in some way that impacts whether the logits processor is argmax invariant). For this reason, `is_argmax_invariant()` is not a class method
|
||||
|
||||
@ -91,6 +91,6 @@ Disaggregated prefilling is highly related to infrastructure, so vLLM relies on
|
||||
|
||||
We recommend three ways of implementations:
|
||||
|
||||
- **Fully-customized connector**: Implement your own `Connector`, and call third-party libraries to send and receive KV caches, and many many more (like editing vLLM's model input to perform customized prefilling, etc). This approach gives you the most control, but at the risk of being incompatible with future vLLM versions.
|
||||
- **Fully-customized connector**: Implement your own `Connector`, and call third-party libraries to send and receive KV caches, and many many more (like editing vLLM's model input to perform customized prefilling, etc.). This approach gives you the most control, but at the risk of being incompatible with future vLLM versions.
|
||||
- **Database-like connector**: Implement your own `LookupBuffer` and support the `insert` and `drop_select` APIs just like SQL.
|
||||
- **Distributed P2P connector**: Implement your own `Pipe` and support the `send_tensor` and `recv_tensor` APIs, just like `torch.distributed`.
|
||||
|
||||
@ -4,7 +4,7 @@ This document shows you how to use [LoRA adapters](https://arxiv.org/abs/2106.09
|
||||
|
||||
LoRA adapters can be used with any vLLM model that implements [SupportsLoRA][vllm.model_executor.models.interfaces.SupportsLoRA].
|
||||
|
||||
Adapters can be efficiently served on a per request basis with minimal overhead. First we download the adapter(s) and save
|
||||
Adapters can be efficiently served on a per-request basis with minimal overhead. First we download the adapter(s) and save
|
||||
them locally with
|
||||
|
||||
```python
|
||||
|
||||
@ -483,7 +483,7 @@ Then, you can use the OpenAI client as follows:
|
||||
)
|
||||
|
||||
# Single-image input inference
|
||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
|
||||
chat_response = client.chat.completions.create(
|
||||
model="microsoft/Phi-3.5-vision-instruct",
|
||||
|
||||
@ -298,7 +298,7 @@ There are two steps to generate and deploy a mixed precision model quantized wit
|
||||
|
||||
Firstly, the layerwise mixed-precision configuration for a given LLM model is searched and then quantized using AMD Quark. We will provide a detailed tutorial with Quark APIs later.
|
||||
|
||||
As examples, we provide some ready-to-use quantized mixed precision model to show the usage in vLLM and the accuracy benifits. They are:
|
||||
As examples, we provide some ready-to-use quantized mixed precision model to show the usage in vLLM and the accuracy benefits. They are:
|
||||
|
||||
- amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
|
||||
- amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
|
||||
|
||||
@ -97,14 +97,13 @@ Currently, there are no pre-built CPU wheels.
|
||||
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists, `auto` (by default), or `nobind` (to disable binding to individual CPU cores and to inherit user-defined OpenMP variables). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively. If set to `nobind`, the number of OpenMP threads is determined by the standard `OMP_NUM_THREADS` environment variable.
|
||||
- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`.
|
||||
- `CPU_VISIBLE_MEMORY_NODES`: specify visible NUMA memory nodes for vLLM CPU workers, similar to ```CUDA_VISIBLE_DEVICES```. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. The variable provides more control for the auto thread-binding feature, such as masking nodes and changing nodes binding sequence.
|
||||
- `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).
|
||||
- `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False).
|
||||
|
||||
## FAQ
|
||||
|
||||
### Which `dtype` should be used?
|
||||
|
||||
- Currently vLLM CPU uses model default settings as `dtype`. However, due to unstable float16 support in torch CPU, it is recommended to explicitly set `dtype=bfloat16` if there are any performance or accuracy problem.
|
||||
- Currently, vLLM CPU uses model default settings as `dtype`. However, due to unstable float16 support in torch CPU, it is recommended to explicitly set `dtype=bfloat16` if there are any performance or accuracy problem.
|
||||
|
||||
### How to launch a vLLM service on CPU?
|
||||
|
||||
@ -191,10 +190,9 @@ vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel
|
||||
- GPTQ (x86 only)
|
||||
- compressed-tensor INT8 W8A8 (x86, s390x)
|
||||
|
||||
### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`?
|
||||
### (x86 only) What is the purpose of `VLLM_CPU_SGL_KERNEL`?
|
||||
|
||||
- Both of them require `amx` CPU flag.
|
||||
- `VLLM_CPU_MOE_PREPACK` can provide better performance for MoE models
|
||||
- `VLLM_CPU_SGL_KERNEL` can provide better performance for MoE models and small-batch scenarios.
|
||||
|
||||
### Why do I see `get_mempolicy: Operation not permitted` when running in Docker?
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
vLLM has experimental support for s390x architecture on IBM Z platform. For now, users must build from source to natively run on IBM Z platform.
|
||||
|
||||
Currently the CPU implementation for s390x architecture supports FP32 datatype only.
|
||||
Currently, the CPU implementation for s390x architecture supports FP32 datatype only.
|
||||
|
||||
!!! warning
|
||||
There are no pre-built wheels or images for this device, so you must build vLLM from source.
|
||||
|
||||
@ -83,7 +83,7 @@ uv pip install dist/*.whl
|
||||
!!! example "Troubleshooting"
|
||||
- **NumPy ≥2.0 error**: Downgrade using `pip install "numpy<2.0"`.
|
||||
- **CMake picks up CUDA**: Add `CMAKE_DISABLE_FIND_PACKAGE_CUDA=ON` to prevent CUDA detection during CPU builds, even if CUDA is installed.
|
||||
- `AMD` requies at least 4th gen processors (Zen 4/Genoa) or higher to support [AVX512](https://www.phoronix.com/review/amd-zen4-avx512) to run vLLM on CPU.
|
||||
- `AMD` requires at least 4th gen processors (Zen 4/Genoa) or higher to support [AVX512](https://www.phoronix.com/review/amd-zen4-avx512) to run vLLM on CPU.
|
||||
- If you receive an error such as: `Could not find a version that satisfies the requirement torch==X.Y.Z+cpu+cpu`, consider updating [pyproject.toml](https://github.com/vllm-project/vllm/blob/main/pyproject.toml) to help pip resolve the dependency.
|
||||
```toml title="pyproject.toml"
|
||||
[build-system]
|
||||
|
||||
@ -1,12 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from argparse import SUPPRESS, HelpFormatter
|
||||
from argparse import SUPPRESS, Action, HelpFormatter
|
||||
from collections.abc import Iterable
|
||||
from importlib.machinery import ModuleSpec
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from pydantic_core import core_schema
|
||||
@ -19,6 +22,11 @@ ARGPARSE_DOC_DIR = ROOT_DIR / "docs/argparse"
|
||||
sys.path.insert(0, str(ROOT_DIR))
|
||||
|
||||
|
||||
def mock_if_no_torch(mock_module: str, mock: MagicMock):
|
||||
if not importlib.util.find_spec("torch"):
|
||||
sys.modules[mock_module] = mock
|
||||
|
||||
|
||||
# Mock custom op code
|
||||
class MockCustomOp:
|
||||
@staticmethod
|
||||
@ -29,18 +37,21 @@ class MockCustomOp:
|
||||
return decorator
|
||||
|
||||
|
||||
noop = lambda *a, **k: None
|
||||
sys.modules["vllm._C"] = MagicMock()
|
||||
sys.modules["vllm.model_executor.custom_op"] = MagicMock(CustomOp=MockCustomOp)
|
||||
sys.modules["vllm.utils.torch_utils"] = MagicMock(direct_register_custom_op=noop)
|
||||
mock_if_no_torch("vllm._C", MagicMock())
|
||||
mock_if_no_torch("vllm.model_executor.custom_op", MagicMock(CustomOp=MockCustomOp))
|
||||
mock_if_no_torch(
|
||||
"vllm.utils.torch_utils", MagicMock(direct_register_custom_op=lambda *a, **k: None)
|
||||
)
|
||||
|
||||
|
||||
# Mock any version checks by reading from compiled CI requirements
|
||||
with open(ROOT_DIR / "requirements/test.txt") as f:
|
||||
VERSIONS = dict(line.strip().split("==") for line in f if "==" in line)
|
||||
importlib.metadata.version = lambda name: VERSIONS.get(name) or "0.0.0"
|
||||
|
||||
|
||||
# Make torch.nn.Parameter safe to inherit from
|
||||
sys.modules["torch.nn"] = MagicMock(Parameter=object)
|
||||
mock_if_no_torch("torch.nn", MagicMock(Parameter=object))
|
||||
|
||||
|
||||
class PydanticMagicMock(MagicMock):
|
||||
@ -49,31 +60,34 @@ class PydanticMagicMock(MagicMock):
|
||||
def __init__(self, *args, **kwargs):
|
||||
name = kwargs.pop("name", None)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__spec__ = importlib.machinery.ModuleSpec(name, None)
|
||||
self.__spec__ = ModuleSpec(name, None)
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type, handler):
|
||||
return core_schema.any_schema()
|
||||
|
||||
|
||||
def auto_mock(module, attr, max_mocks=100):
|
||||
def auto_mock(module_name: str, attr: str, max_mocks: int = 100):
|
||||
"""Function that automatically mocks missing modules during imports."""
|
||||
logger.info("Importing %s from %s", attr, module)
|
||||
logger.info("Importing %s from %s", attr, module_name)
|
||||
|
||||
for _ in range(max_mocks):
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# First treat attr as an attr, then as a submodule
|
||||
return getattr(
|
||||
importlib.import_module(module),
|
||||
attr,
|
||||
importlib.import_module(f"{module}.{attr}"),
|
||||
)
|
||||
if hasattr(module, attr):
|
||||
return getattr(module, attr)
|
||||
|
||||
return importlib.import_module(f"{module_name}.{attr}")
|
||||
except ModuleNotFoundError as e:
|
||||
assert e.name is not None
|
||||
logger.info("Mocking %s for argparse doc generation", e.name)
|
||||
sys.modules[e.name] = PydanticMagicMock(name=e.name)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to import %s.%s: %s", module, attr, e)
|
||||
except Exception:
|
||||
logger.exception("Failed to import %s.%s: %s", module_name, attr)
|
||||
|
||||
raise ImportError(
|
||||
f"Failed to import {module}.{attr} after mocking {max_mocks} imports"
|
||||
f"Failed to import {module_name}.{attr} after mocking {max_mocks} imports"
|
||||
)
|
||||
|
||||
|
||||
@ -91,21 +105,26 @@ ChatCommand = auto_mock("vllm.entrypoints.cli.openai", "ChatCommand")
|
||||
CompleteCommand = auto_mock("vllm.entrypoints.cli.openai", "CompleteCommand")
|
||||
openai_cli_args = auto_mock("vllm.entrypoints.openai", "cli_args")
|
||||
openai_run_batch = auto_mock("vllm.entrypoints.openai", "run_batch")
|
||||
FlexibleArgumentParser = auto_mock(
|
||||
"vllm.utils.argparse_utils", "FlexibleArgumentParser"
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
else:
|
||||
FlexibleArgumentParser = auto_mock(
|
||||
"vllm.utils.argparse_utils", "FlexibleArgumentParser"
|
||||
)
|
||||
|
||||
|
||||
class MarkdownFormatter(HelpFormatter):
|
||||
"""Custom formatter that generates markdown for argument groups."""
|
||||
|
||||
def __init__(self, prog, starting_heading_level=3):
|
||||
super().__init__(prog, max_help_position=float("inf"), width=float("inf"))
|
||||
def __init__(self, prog: str, starting_heading_level: int = 3):
|
||||
super().__init__(prog, max_help_position=sys.maxsize, width=sys.maxsize)
|
||||
|
||||
self._section_heading_prefix = "#" * starting_heading_level
|
||||
self._argument_heading_prefix = "#" * (starting_heading_level + 1)
|
||||
self._markdown_output = []
|
||||
|
||||
def start_section(self, heading):
|
||||
def start_section(self, heading: str):
|
||||
if heading not in {"positional arguments", "options"}:
|
||||
heading_md = f"\n{self._section_heading_prefix} {heading}\n\n"
|
||||
self._markdown_output.append(heading_md)
|
||||
@ -113,14 +132,14 @@ class MarkdownFormatter(HelpFormatter):
|
||||
def end_section(self):
|
||||
pass
|
||||
|
||||
def add_text(self, text):
|
||||
def add_text(self, text: str):
|
||||
if text:
|
||||
self._markdown_output.append(f"{text.strip()}\n\n")
|
||||
|
||||
def add_usage(self, usage, actions, groups, prefix=None):
|
||||
pass
|
||||
|
||||
def add_arguments(self, actions):
|
||||
def add_arguments(self, actions: Iterable[Action]):
|
||||
for action in actions:
|
||||
if len(action.option_strings) == 0 or "--help" in action.option_strings:
|
||||
continue
|
||||
@ -169,7 +188,7 @@ def create_parser(add_cli_args, **kwargs) -> FlexibleArgumentParser:
|
||||
# Auto-mock runtime imports
|
||||
if tb_list := traceback.extract_tb(e.__traceback__):
|
||||
path = Path(tb_list[-1].filename).relative_to(ROOT_DIR)
|
||||
auto_mock(module=".".join(path.parent.parts), attr=path.stem)
|
||||
auto_mock(module_name=".".join(path.parent.parts), attr=path.stem)
|
||||
return create_parser(add_cli_args, **kwargs)
|
||||
else:
|
||||
raise e
|
||||
@ -209,7 +228,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
||||
|
||||
# Generate documentation for each parser
|
||||
for stem, parser in parsers.items():
|
||||
doc_path = ARGPARSE_DOC_DIR / f"{stem}.md"
|
||||
doc_path = ARGPARSE_DOC_DIR / f"{stem}.inc.md"
|
||||
# Specify encoding for building on Windows
|
||||
with open(doc_path, "w", encoding="utf-8") as f:
|
||||
f.write(super(type(parser), parser).format_help())
|
||||
|
||||
@ -351,6 +351,7 @@ th {
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|
|
||||
| `AfmoeForCausalLM` | Afmoe | TBA | ✅︎ | ✅︎ |
|
||||
| `ApertusForCausalLM` | Apertus | `swiss-ai/Apertus-8B-2509`, `swiss-ai/Apertus-70B-Instruct-2509`, etc. | ✅︎ | ✅︎ |
|
||||
| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ |
|
||||
| `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ |
|
||||
@ -670,7 +671,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I<sup>+</sup> | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ |
|
||||
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
|
||||
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
|
||||
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
|
||||
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>E+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
|
||||
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
|
||||
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
|
||||
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
|
||||
@ -685,7 +686,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ |
|
||||
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ |
|
||||
| `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I<sup>+</sup> | `lightonai/LightOnOCR-1B`, etc | ✅︎ | ✅︎ |
|
||||
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ |
|
||||
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ |
|
||||
| `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ |
|
||||
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ |
|
||||
|
||||
@ -293,7 +293,7 @@ and passing a list of `messages` in the request. Refer to the examples below for
|
||||
base_url="http://localhost:8000/v1",
|
||||
api_key="EMPTY",
|
||||
)
|
||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Using vLLM
|
||||
|
||||
First, vLLM must be [installed](../getting_started/installation/) for your chosen device in either a Python or Docker environment.
|
||||
First, vLLM must be [installed](../getting_started/installation/README.md) for your chosen device in either a Python or Docker environment.
|
||||
|
||||
Then, vLLM supports the following usage patterns:
|
||||
|
||||
|
||||
@ -62,7 +62,7 @@ ray.init()
|
||||
|
||||
# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
|
||||
# Learn more about Ray placement groups:
|
||||
# https://docs.ray.io/en/latest/placement-groups.html
|
||||
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
|
||||
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
|
||||
ray.get(pg_inference.ready())
|
||||
scheduling_inference = PlacementGroupSchedulingStrategy(
|
||||
|
||||
162
examples/offline_inference/rlhf_online_quant.py
Normal file
162
examples/offline_inference/rlhf_online_quant.py
Normal file
@ -0,0 +1,162 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.
|
||||
|
||||
The script separates training and inference workloads onto distinct GPUs
|
||||
so that Ray can manage process placement and inter-process communication.
|
||||
A Hugging Face Transformer model occupies GPU 0 for training, whereas a
|
||||
tensor-parallel vLLM inference engine occupies GPU 1–2.
|
||||
|
||||
The example performs the following steps:
|
||||
|
||||
* Load the training model on GPU 0.
|
||||
* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
|
||||
and Ray placement groups.
|
||||
* Generate text from a list of prompts using the inference engine.
|
||||
* Update the weights of the training model and broadcast the updated weights
|
||||
to the inference engine by using a Ray collective RPC group. Note that
|
||||
for demonstration purposes we simply zero out the weights.
|
||||
|
||||
For a production-ready implementation that supports multiple training and
|
||||
inference replicas, see the OpenRLHF framework:
|
||||
https://github.com/OpenRLHF/OpenRLHF
|
||||
|
||||
This example assumes a single-node cluster with three GPUs, but Ray
|
||||
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
|
||||
workloads. Residual GPU activity interferes with vLLM memory profiling and
|
||||
causes unexpected behavior.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from ray.util.placement_group import placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
from rlhf_utils import stateless_init_process_group
|
||||
from torchao.core.config import config_to_dict
|
||||
from torchao.quantization import (
|
||||
Float8DynamicActivationFloat8WeightConfig,
|
||||
PerRow,
|
||||
)
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils.network_utils import get_ip, get_open_port
|
||||
|
||||
|
||||
class MyLLM(LLM):
|
||||
"""Configure the vLLM worker for Ray placement group execution."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
|
||||
# so that vLLM can manage its own device placement within the worker.
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# Load the OPT-125M model onto GPU 0 for the training workload.
|
||||
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
|
||||
train_model.to("cuda:0")
|
||||
|
||||
# Initialize Ray and set the visible devices. The vLLM engine will
|
||||
# be placed on GPUs 1 and 2.
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
|
||||
ray.init()
|
||||
|
||||
# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
|
||||
# Learn more about Ray placement groups:
|
||||
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
|
||||
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
|
||||
ray.get(pg_inference.ready())
|
||||
scheduling_inference = PlacementGroupSchedulingStrategy(
|
||||
placement_group=pg_inference,
|
||||
placement_group_capture_child_tasks=True,
|
||||
placement_group_bundle_index=0,
|
||||
)
|
||||
|
||||
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
|
||||
# start-up latency.
|
||||
|
||||
# generate torchao quantization config for RL rollout
|
||||
# see https://github.com/vllm-project/vllm/pull/23014 for instructions to
|
||||
# use serialized config files instead of passing around json string
|
||||
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
|
||||
|
||||
json_str = json.dumps(config_to_dict(config))
|
||||
|
||||
llm = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0,
|
||||
scheduling_strategy=scheduling_inference,
|
||||
)(MyLLM).remote(
|
||||
model="facebook/opt-125m",
|
||||
hf_overrides={"quantization_config_dict_json": json_str},
|
||||
enforce_eager=True,
|
||||
worker_extension_cls="rlhf_utils.WorkerExtension",
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend="ray",
|
||||
)
|
||||
|
||||
# Generate text from the prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
|
||||
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||
|
||||
print("-" * 50)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 50)
|
||||
|
||||
# Set up the communication channel between the training process and the
|
||||
# inference engine.
|
||||
master_address = get_ip()
|
||||
master_port = get_open_port()
|
||||
|
||||
handle = llm.collective_rpc.remote(
|
||||
"init_weight_update_group", args=(master_address, master_port, 1, 3)
|
||||
)
|
||||
|
||||
model_update_group = stateless_init_process_group(
|
||||
master_address, master_port, 0, 3, torch.device("cuda:0")
|
||||
)
|
||||
ray.get(handle)
|
||||
|
||||
# Simulate a training step by zeroing out all model weights.
|
||||
# In a real RLHF training loop the weights would be updated using the gradient
|
||||
# from an RL objective such as PPO on a reward model.
|
||||
for name, p in train_model.named_parameters():
|
||||
p.data.zero_()
|
||||
|
||||
# Synchronize the updated weights to the inference engine.
|
||||
for name, p in train_model.named_parameters():
|
||||
dtype_name = str(p.dtype).split(".")[-1]
|
||||
handle = llm.collective_rpc.remote(
|
||||
"update_weight", args=(name, dtype_name, p.shape)
|
||||
)
|
||||
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
|
||||
ray.get(handle)
|
||||
|
||||
# Verify that the inference weights have been updated.
|
||||
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
|
||||
|
||||
# Generate text with the updated model. The output is expected to be nonsense
|
||||
# because the weights are zero.
|
||||
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||
print("-" * 50)
|
||||
for output in outputs_updated:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 50)
|
||||
@ -266,7 +266,7 @@ def get_query(modality: QueryModality):
|
||||
return ImageQuery(
|
||||
modality="image",
|
||||
image=fetch_image(
|
||||
"https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg" # noqa: E501
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/eskimo.jpg" # noqa: E501
|
||||
),
|
||||
)
|
||||
|
||||
@ -275,7 +275,7 @@ def get_query(modality: QueryModality):
|
||||
modality="text+image",
|
||||
text="A cat standing in the snow.",
|
||||
image=fetch_image(
|
||||
"https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/179px-Felis_catus-cat_on_snow.jpg" # noqa: E501
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/cat_snow.jpg" # noqa: E501
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -66,7 +66,7 @@ def run_text_only(model: str, max_completion_tokens: int) -> None:
|
||||
# Single-image input inference
|
||||
def run_single_image(model: str, max_completion_tokens: int) -> None:
|
||||
## Use image url in the payload
|
||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
chat_completion_from_url = client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
|
||||
@ -21,7 +21,7 @@ from PIL import Image
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
|
||||
|
||||
def create_chat_embeddings(
|
||||
|
||||
@ -1,6 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
echo "vLLM linting system has been moved from format.sh to pre-commit hooks."
|
||||
echo "Please run 'pip install -r requirements/lint.txt', followed by"
|
||||
echo "'pre-commit install' to install the pre-commit hooks."
|
||||
echo "Then linters will run automatically before each commit."
|
||||
@ -30,7 +30,7 @@ filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/31
|
||||
partial-json-parser # used for parsing partial JSON outputs
|
||||
pyzmq >= 25.0.0
|
||||
msgspec
|
||||
gguf >= 0.13.0
|
||||
gguf >= 0.17.0
|
||||
mistral_common[image] >= 1.8.5
|
||||
opencv-python-headless >= 4.11.0 # required for video IO
|
||||
pyyaml
|
||||
|
||||
@ -4,8 +4,9 @@ packaging>=24.2
|
||||
setuptools>=77.0.3,<81.0.0
|
||||
setuptools-scm>=8
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch==2.8.0+cpu; platform_machine == "x86_64"
|
||||
torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_system == "Darwin"
|
||||
torch==2.8.0+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
|
||||
torch==2.9.0; platform_system == "Darwin"
|
||||
torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64"
|
||||
scons; platform_machine == "aarch64" # needed to build Arm Compute Library (ACL)
|
||||
wheel
|
||||
jinja2>=3.1.6
|
||||
|
||||
@ -22,7 +22,6 @@ datasets # for benchmark scripts
|
||||
|
||||
# Intel Extension for PyTorch, only for x86_64 CPUs
|
||||
intel-openmp==2024.2.1; platform_machine == "x86_64"
|
||||
intel_extension_for_pytorch==2.8.0; platform_machine == "x86_64"
|
||||
triton==3.2.0; platform_machine == "x86_64" # Triton is required for torch 2.6+cpu, as it is imported in torch.compile.
|
||||
|
||||
# Use this to gather CPU info and optimize based on ARM Neoverse cores
|
||||
|
||||
17
setup.py
17
setup.py
@ -299,6 +299,20 @@ class cmake_build_ext(build_ext):
|
||||
os.makedirs(os.path.dirname(dst_file), exist_ok=True)
|
||||
self.copy_file(file, dst_file)
|
||||
|
||||
if _is_cuda() or _is_hip():
|
||||
# copy vllm/third_party/triton_kernels/**/*.py from self.build_lib
|
||||
# to current directory so that they can be included in the editable
|
||||
# build
|
||||
print(
|
||||
f"Copying {self.build_lib}/vllm/third_party/triton_kernels "
|
||||
"to vllm/third_party/triton_kernels"
|
||||
)
|
||||
shutil.copytree(
|
||||
f"{self.build_lib}/vllm/third_party/triton_kernels",
|
||||
"vllm/third_party/triton_kernels",
|
||||
dirs_exist_ok=True,
|
||||
)
|
||||
|
||||
|
||||
class precompiled_build_ext(build_ext):
|
||||
"""Disables extension building when using precompiled binaries."""
|
||||
@ -633,6 +647,9 @@ ext_modules = []
|
||||
if _is_cuda() or _is_hip():
|
||||
ext_modules.append(CMakeExtension(name="vllm._moe_C"))
|
||||
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
||||
# Optional since this doesn't get built (produce an .so file). This is just
|
||||
# copying the relevant .py files from the source repository.
|
||||
ext_modules.append(CMakeExtension(name="vllm.triton_kernels", optional=True))
|
||||
|
||||
if _is_hip():
|
||||
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
|
||||
|
||||
@ -20,13 +20,22 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from ..utils import flat_product, multi_gpu_test
|
||||
|
||||
is_blackwell = lambda: current_platform.is_device_capability(100)
|
||||
"""Are we running on Blackwell, a lot of tests depend on it"""
|
||||
|
||||
|
||||
class Matches(NamedTuple):
|
||||
attention_fusion: int = 0
|
||||
allreduce_fusion: int = 0
|
||||
sequence_parallel: int = 0
|
||||
async_tp: int = 0
|
||||
|
||||
|
||||
class ModelBackendTestCase(NamedTuple):
|
||||
model_name: str
|
||||
model_kwargs: dict[str, Any]
|
||||
backend: AttentionBackendEnum
|
||||
attention_fusions: int
|
||||
allreduce_fusions: int | None = None
|
||||
matches: Matches
|
||||
|
||||
|
||||
MODELS_FP8: list[ModelBackendTestCase] = []
|
||||
@ -38,17 +47,33 @@ if current_platform.is_cuda():
|
||||
ModelBackendTestCase(
|
||||
# Use smaller model for L40s in CI
|
||||
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||
attention_fusions=32,
|
||||
allreduce_fusions=65,
|
||||
# TODO while llama4 is broken, use FLASHINFER for llama3 on Blackwell
|
||||
# so FI attention+fp8_quant is at least tested once
|
||||
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
||||
backend=AttentionBackendEnum.FLASHINFER
|
||||
if is_blackwell()
|
||||
else AttentionBackendEnum.TRITON_ATTN,
|
||||
matches=Matches(
|
||||
attention_fusion=32,
|
||||
allreduce_fusion=65,
|
||||
sequence_parallel=65,
|
||||
async_tp=128,
|
||||
),
|
||||
),
|
||||
ModelBackendTestCase(
|
||||
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
||||
backend=AttentionBackendEnum.FLASHINFER,
|
||||
attention_fusions=48,
|
||||
allreduce_fusions=96,
|
||||
# TODO FlashInfer attn broken on Hopper with kvcache=fp8:
|
||||
# https://github.com/vllm-project/vllm/issues/28568
|
||||
# TODO FlashInfer attn broken on Blackwell for llama4:
|
||||
# https://github.com/vllm-project/vllm/issues/28604
|
||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||
matches=Matches(
|
||||
attention_fusion=48,
|
||||
allreduce_fusion=96,
|
||||
sequence_parallel=96,
|
||||
async_tp=95, # mlp is moe, no fusion there
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
@ -57,8 +82,12 @@ if current_platform.is_cuda():
|
||||
model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
|
||||
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
||||
backend=AttentionBackendEnum.FLASHINFER,
|
||||
attention_fusions=32,
|
||||
allreduce_fusions=65,
|
||||
matches=Matches(
|
||||
attention_fusion=32,
|
||||
allreduce_fusion=65,
|
||||
sequence_parallel=65,
|
||||
async_tp=128,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
@ -68,15 +97,23 @@ if current_platform.is_cuda():
|
||||
model_name="meta-llama/Llama-3.1-8B-Instruct",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||
attention_fusions=0,
|
||||
allreduce_fusions=65,
|
||||
matches=Matches(
|
||||
attention_fusion=0,
|
||||
allreduce_fusion=65,
|
||||
sequence_parallel=65,
|
||||
async_tp=128,
|
||||
),
|
||||
),
|
||||
ModelBackendTestCase(
|
||||
model_name="Qwen/Qwen3-30B-A3B",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||
attention_fusions=0,
|
||||
allreduce_fusions=97,
|
||||
matches=Matches(
|
||||
attention_fusion=0,
|
||||
allreduce_fusion=97,
|
||||
sequence_parallel=97,
|
||||
async_tp=96, # MLP is MoE, half the fusions of dense
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
@ -86,19 +123,19 @@ elif current_platform.is_rocm():
|
||||
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||
attention_fusions=32,
|
||||
matches=Matches(attention_fusion=32),
|
||||
),
|
||||
ModelBackendTestCase(
|
||||
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=AttentionBackendEnum.ROCM_ATTN,
|
||||
attention_fusions=32,
|
||||
matches=Matches(attention_fusion=32),
|
||||
),
|
||||
ModelBackendTestCase(
|
||||
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
|
||||
attention_fusions=32,
|
||||
matches=Matches(attention_fusion=32),
|
||||
),
|
||||
]
|
||||
|
||||
@ -106,8 +143,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, model_kwargs, backend, "
|
||||
"attention_fusions, allreduce_fusions, custom_ops",
|
||||
"model_name, model_kwargs, backend, matches, custom_ops",
|
||||
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
|
||||
list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
|
||||
# quant_fp4 only has the custom impl
|
||||
@ -118,15 +154,14 @@ def test_attn_quant(
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
backend: AttentionBackendEnum,
|
||||
attention_fusions: int,
|
||||
allreduce_fusions: int,
|
||||
matches: Matches,
|
||||
custom_ops: str,
|
||||
inductor_graph_partition: bool,
|
||||
caplog_mp_spawn,
|
||||
monkeypatch,
|
||||
):
|
||||
if backend == AttentionBackendEnum.FLASHINFER and (
|
||||
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
|
||||
not is_blackwell() or not has_flashinfer()
|
||||
):
|
||||
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
|
||||
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
@ -169,12 +204,12 @@ def test_attn_quant(
|
||||
with caplog_mp_spawn(logging.DEBUG) as log_holder:
|
||||
run_model(compilation_config, model_name, **model_kwargs)
|
||||
|
||||
matches = re.findall(
|
||||
log_matches = re.findall(
|
||||
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
|
||||
log_holder.text,
|
||||
)
|
||||
assert len(matches) == 1, log_holder.text
|
||||
assert int(matches[0]) == attention_fusions
|
||||
assert len(log_matches) == 1, log_holder.text
|
||||
assert int(log_matches[0]) == matches.attention_fusion
|
||||
|
||||
|
||||
CUSTOM_OPS_RMS_NORM = ["-rms_norm", "+rms_norm"]
|
||||
@ -187,8 +222,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, model_kwargs, backend, "
|
||||
"attention_fusions, allreduce_fusions, custom_ops",
|
||||
"model_name, model_kwargs, backend, matches, custom_ops",
|
||||
# Toggle RMSNorm and QuantFP8 for FP8 models
|
||||
list(
|
||||
flat_product(
|
||||
@ -209,8 +243,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
|
||||
model_name: str,
|
||||
model_kwargs: dict,
|
||||
backend: AttentionBackendEnum,
|
||||
attention_fusions: int,
|
||||
allreduce_fusions: int,
|
||||
matches: Matches,
|
||||
custom_ops: str,
|
||||
inductor_graph_partition: bool,
|
||||
caplog_mp_spawn,
|
||||
@ -219,6 +252,13 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
|
||||
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("Inductor graph partition requires torch>=2.9")
|
||||
|
||||
if "fp4" in model_name.lower() and not is_blackwell():
|
||||
pytest.skip("NVFP4 quant requires Blackwell")
|
||||
|
||||
if backend == AttentionBackendEnum.FLASHINFER and not is_blackwell():
|
||||
# FlashInfer attn fusion requires Blackwell
|
||||
matches = matches._replace(attention_fusion=0)
|
||||
|
||||
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||
|
||||
if inductor_graph_partition:
|
||||
@ -258,23 +298,135 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
|
||||
run_model(
|
||||
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
|
||||
)
|
||||
matches = re.findall(
|
||||
log_matches = re.findall(
|
||||
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
|
||||
log_holder.text,
|
||||
)
|
||||
assert len(matches) == 2, log_holder.text
|
||||
assert len(log_matches) == 2, log_holder.text
|
||||
|
||||
assert int(matches[0]) == attention_fusions
|
||||
assert int(matches[1]) == attention_fusions
|
||||
assert int(log_matches[0]) == matches.attention_fusion
|
||||
assert int(log_matches[1]) == matches.attention_fusion
|
||||
|
||||
matches = re.findall(
|
||||
log_matches = re.findall(
|
||||
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
|
||||
log_holder.text,
|
||||
)
|
||||
assert len(matches) == 2, log_holder.text
|
||||
assert len(log_matches) == 2, log_holder.text
|
||||
|
||||
assert int(matches[0]) == allreduce_fusions
|
||||
assert int(matches[1]) == allreduce_fusions
|
||||
assert int(log_matches[0]) == matches.allreduce_fusion
|
||||
assert int(log_matches[1]) == matches.allreduce_fusion
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, model_kwargs, backend, matches, custom_ops",
|
||||
# Toggle RMSNorm and QuantFP8 for FP8 models
|
||||
list(
|
||||
flat_product(
|
||||
MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
|
||||
)
|
||||
)
|
||||
# Toggle RMSNorm for FP4 models and unquant models
|
||||
+ list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
|
||||
)
|
||||
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(),
|
||||
reason="sequence parallel only tested on CUDA",
|
||||
)
|
||||
def test_tp2_attn_quant_async_tp(
|
||||
model_name: str,
|
||||
model_kwargs: dict,
|
||||
backend: AttentionBackendEnum,
|
||||
matches: Matches,
|
||||
custom_ops: str,
|
||||
inductor_graph_partition: bool,
|
||||
caplog_mp_spawn,
|
||||
monkeypatch,
|
||||
):
|
||||
if is_blackwell():
|
||||
# TODO: https://github.com/vllm-project/vllm/issues/27893
|
||||
pytest.skip("Blackwell is not supported for AsyncTP pass")
|
||||
|
||||
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("Inductor graph partition requires torch>=2.9")
|
||||
|
||||
if "fp4" in model_name.lower() and not is_blackwell():
|
||||
pytest.skip("NVFP4 quant requires Blackwell")
|
||||
|
||||
if backend == AttentionBackendEnum.FLASHINFER:
|
||||
if not has_flashinfer():
|
||||
pytest.skip("FlashInfer backend requires flashinfer installed")
|
||||
if not is_blackwell():
|
||||
# FlashInfer attn fusion requires Blackwell
|
||||
matches = matches._replace(attention_fusion=0)
|
||||
|
||||
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||
|
||||
if inductor_graph_partition:
|
||||
mode = CUDAGraphMode.FULL_AND_PIECEWISE
|
||||
splitting_ops: list[str] | None = None
|
||||
else:
|
||||
mode = CUDAGraphMode.FULL_DECODE_ONLY
|
||||
splitting_ops = []
|
||||
|
||||
# Disable, compile cache to make sure custom passes run.
|
||||
# Otherwise, we can't verify fusion happened through the logs.
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
|
||||
# To capture subprocess logs, we need to know whether spawn or fork is used.
|
||||
# Force spawn as it is more general.
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
compilation_config = CompilationConfig(
|
||||
# Testing properties
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
cudagraph_mode=mode,
|
||||
custom_ops=custom_ops_list,
|
||||
splitting_ops=splitting_ops,
|
||||
# Common
|
||||
level=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(
|
||||
enable_attn_fusion=True,
|
||||
enable_noop=True,
|
||||
enable_sequence_parallelism=True,
|
||||
enable_async_tp=True,
|
||||
),
|
||||
# Inductor caches custom passes by default as well via uuid
|
||||
inductor_compile_config={"force_disable_caches": True},
|
||||
)
|
||||
|
||||
with caplog_mp_spawn(logging.DEBUG) as log_holder:
|
||||
run_model(
|
||||
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
|
||||
)
|
||||
log_matches = re.findall(
|
||||
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
|
||||
log_holder.text,
|
||||
)
|
||||
assert len(log_matches) == 2, log_holder.text
|
||||
|
||||
assert int(log_matches[0]) == matches.attention_fusion
|
||||
assert int(log_matches[1]) == matches.attention_fusion
|
||||
|
||||
log_matches = re.findall(
|
||||
r"sequence_parallelism.py:\d+] Replaced (\d+) patterns",
|
||||
log_holder.text,
|
||||
)
|
||||
assert len(log_matches) == 2, log_holder.text
|
||||
|
||||
assert int(log_matches[0]) == matches.sequence_parallel
|
||||
assert int(log_matches[1]) == matches.sequence_parallel
|
||||
|
||||
log_matches = re.findall(
|
||||
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
|
||||
log_holder.text,
|
||||
)
|
||||
assert len(log_matches) == 2, log_holder.text
|
||||
|
||||
assert int(log_matches[0]) == matches.async_tp
|
||||
assert int(log_matches[1]) == matches.async_tp
|
||||
|
||||
|
||||
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
|
||||
|
||||
@ -5,15 +5,15 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.fusion import RMSNormQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||
from vllm.compilation.fx_utils import find_auto_fn
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CUDAGraphMode,
|
||||
DeviceConfig,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
@ -27,6 +27,7 @@ from vllm.distributed.parallel_state import (
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
@ -43,172 +44,157 @@ prompts = [
|
||||
]
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.gate_proj = torch.nn.Parameter(
|
||||
torch.empty((intermediate_size, hidden_size))
|
||||
)
|
||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
||||
# Initialize weights
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
self.eps = eps
|
||||
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
|
||||
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
"""
|
||||
Forward pass implementing the operations in the FX graph
|
||||
def forward(self, x):
|
||||
z = torch.relu(x)
|
||||
x = resid = tensor_model_parallel_all_reduce(z)
|
||||
y = self.norm[0](x)
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor
|
||||
residual: Residual tensor from previous layer
|
||||
z2 = torch.mm(y, self.w[0])
|
||||
x2 = tensor_model_parallel_all_reduce(z2)
|
||||
|
||||
Returns:
|
||||
Tuple containing the output tensor
|
||||
"""
|
||||
# Reshape input
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
# matrix multiplication
|
||||
permute = self.gate_proj.permute(1, 0)
|
||||
mm = torch.mm(view, permute)
|
||||
z3 = torch.mm(y2, self.w[1])
|
||||
x3 = tensor_model_parallel_all_reduce(z3)
|
||||
|
||||
# Tensor parallel all-reduce
|
||||
all_reduce = tensor_model_parallel_all_reduce(mm)
|
||||
y3, resid = self.norm[2](x3, resid)
|
||||
|
||||
# layer normalization
|
||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||
z4 = torch.mm(y3, self.w[2])
|
||||
x4 = tensor_model_parallel_all_reduce(z4)
|
||||
|
||||
return norm_output, residual_output
|
||||
y4, resid = self.norm[3](x4, resid)
|
||||
return y4
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_reduce.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [
|
||||
torch.ops.vllm.reduce_scatter.default,
|
||||
torch.ops.vllm.all_gather.default,
|
||||
torch.ops.vllm.reduce_scatter.default,
|
||||
]
|
||||
|
||||
def ops_in_model(self):
|
||||
return [torch.ops._C.fused_add_rms_norm.default]
|
||||
if RMSNorm.enabled():
|
||||
return [
|
||||
torch.ops._C.rms_norm.default,
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
class TestQuantModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
self.gate_proj = torch.nn.Parameter(
|
||||
torch.empty((intermediate_size, hidden_size)), requires_grad=False
|
||||
)
|
||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
||||
# Initialize weights
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
self.w = [
|
||||
torch.rand(hidden_size, hidden_size)
|
||||
.to(dtype=current_platform.fp8_dtype())
|
||||
.t()
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
|
||||
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
# Create a weight that is compatible with torch._scaled_mm,
|
||||
# which expects a column-major layout.
|
||||
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
"""
|
||||
Forward pass implementing the operations in the FX graph
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor
|
||||
residual: Residual tensor from previous layer
|
||||
|
||||
Returns:
|
||||
Tuple containing the output tensor
|
||||
"""
|
||||
# Reshape input
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
|
||||
# matrix multiplication
|
||||
permute = self.gate_proj.permute(1, 0)
|
||||
mm = torch.mm(view, permute)
|
||||
|
||||
# Tensor parallel all-reduce
|
||||
all_reduce = tensor_model_parallel_all_reduce(mm)
|
||||
|
||||
# layer normalization
|
||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||
|
||||
# scaled_mm with static input quantization
|
||||
fp8_linear_result = self.fp8_linear.apply(
|
||||
norm_output,
|
||||
self.w,
|
||||
self.wscale,
|
||||
input_scale=self.scale.to(norm_output.device),
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
)
|
||||
|
||||
return fp8_linear_result, residual_output
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
|
||||
def ops_in_model_before(self):
|
||||
ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP
|
||||
# The following are only removed if fusion happens
|
||||
if (
|
||||
self.vllm_config
|
||||
and self.vllm_config.compilation_config.pass_config.enable_fusion
|
||||
):
|
||||
ops_to_remove.extend(
|
||||
[
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
torch.ops._C.static_scaled_fp8_quant.default,
|
||||
]
|
||||
)
|
||||
return ops_to_remove
|
||||
def forward(self, hidden_states):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
z = torch.relu(hidden_states)
|
||||
x = resid = tensor_model_parallel_all_reduce(z)
|
||||
y = self.norm[0](x)
|
||||
|
||||
z2 = self.fp8_linear.apply(
|
||||
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
|
||||
)
|
||||
|
||||
x2 = tensor_model_parallel_all_reduce(z2)
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
z3 = self.fp8_linear.apply(
|
||||
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||
)
|
||||
|
||||
x3 = tensor_model_parallel_all_reduce(z3)
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
z4 = self.fp8_linear.apply(
|
||||
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
|
||||
)
|
||||
x4 = tensor_model_parallel_all_reduce(z4)
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
return y4
|
||||
|
||||
def ops_in_model_after(self):
|
||||
ops_to_add = [
|
||||
torch.ops.vllm.reduce_scatter.default,
|
||||
return [
|
||||
torch.ops.vllm.all_gather.default,
|
||||
torch.ops.vllm.reduce_scatter.default,
|
||||
]
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
torch.ops.vllm.all_reduce.default,
|
||||
]
|
||||
# The following is only added if fusion happens
|
||||
if (
|
||||
self.vllm_config
|
||||
and self.vllm_config.compilation_config.pass_config.enable_fusion
|
||||
):
|
||||
ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
|
||||
return ops_to_add
|
||||
|
||||
def ops_in_model(self):
|
||||
if (
|
||||
self.vllm_config
|
||||
and self.vllm_config.compilation_config.pass_config.enable_fusion
|
||||
):
|
||||
# If fusion happens, the fused op is the one
|
||||
# we check for (de)functionalization
|
||||
if self.vllm_config.compilation_config.pass_config.enable_fusion:
|
||||
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
|
||||
else:
|
||||
# If no fusion, the original ops are checked
|
||||
elif RMSNorm.enabled():
|
||||
return [
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
# TODO functionalization pass does not handle this yet
|
||||
# torch.ops._C.static_scaled_fp8_quant.default,
|
||||
]
|
||||
elif self.fp8_linear.quant_fp8.enabled():
|
||||
return [
|
||||
torch.ops._C.static_scaled_fp8_quant.default,
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel])
|
||||
@pytest.mark.parametrize(
|
||||
"test_model_cls, custom_ops",
|
||||
[
|
||||
(TestAllReduceRMSNormModel, "+rms_norm"),
|
||||
(TestAllReduceRMSNormModel, "-rms_norm"),
|
||||
(TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"),
|
||||
(TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"),
|
||||
(TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"),
|
||||
(TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [16])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("enable_fusion", [True, False])
|
||||
@pytest.mark.parametrize("dynamic", [False, True])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||
def test_sequence_parallelism_pass(
|
||||
test_model_cls: type[torch.nn.Module],
|
||||
custom_ops: str,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
enable_fusion: bool,
|
||||
dynamic: bool,
|
||||
):
|
||||
num_processes = 2
|
||||
|
||||
@ -220,11 +206,13 @@ def test_sequence_parallelism_pass(
|
||||
args=(
|
||||
num_processes,
|
||||
test_model_cls,
|
||||
custom_ops,
|
||||
batch_size,
|
||||
seq_len,
|
||||
hidden_size,
|
||||
dtype,
|
||||
enable_fusion,
|
||||
dynamic,
|
||||
),
|
||||
nprocs=nprocs,
|
||||
)
|
||||
@ -236,11 +224,13 @@ def sequence_parallelism_pass_on_test_model(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
test_model_cls: type[torch.nn.Module],
|
||||
custom_ops: str,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
enable_fusion: bool,
|
||||
dynamic: bool,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
@ -264,12 +254,16 @@ def sequence_parallelism_pass_on_test_model(
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# configure vllm config for SequenceParallelismPass
|
||||
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||
compilation_config = CompilationConfig(
|
||||
splitting_ops=[], # avoid automatic rms_norm enablement
|
||||
cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings
|
||||
custom_ops=custom_ops_list,
|
||||
pass_config=PassConfig(
|
||||
enable_sequence_parallelism=True,
|
||||
enable_fusion=enable_fusion,
|
||||
enable_noop=True,
|
||||
)
|
||||
),
|
||||
) # NoOp needed for fusion
|
||||
device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
|
||||
@ -289,7 +283,6 @@ def sequence_parallelism_pass_on_test_model(
|
||||
with set_current_vllm_config(vllm_config):
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.splitting_ops
|
||||
@ -310,38 +303,29 @@ def sequence_parallelism_pass_on_test_model(
|
||||
|
||||
passes_for_backend.append(cleanup_pass)
|
||||
|
||||
backend_no_func = TestBackend(*passes_for_backend)
|
||||
backend_func = TestBackend(*passes_for_backend, func_pass)
|
||||
backend = TestBackend(*passes_for_backend)
|
||||
|
||||
model = test_model_cls(hidden_size, hidden_size * 2)
|
||||
model = test_model_cls(hidden_size)
|
||||
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
|
||||
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
|
||||
compiled_model_no_func(hidden_states, residual)
|
||||
compiled_model_func = torch.compile(model, backend=backend_func)
|
||||
compiled_model_func(hidden_states, residual)
|
||||
if dynamic:
|
||||
torch._dynamo.mark_dynamic(hidden_states, 0)
|
||||
|
||||
assert sequence_parallelism_pass.matched_count == 1
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states)
|
||||
|
||||
assert sequence_parallelism_pass.matched_count == 4
|
||||
|
||||
# In pre-nodes, all reduce should be there,
|
||||
# reduce scatter and all gather should not
|
||||
backend_no_func.check_before_ops(model.ops_in_model_before())
|
||||
for op in model.ops_in_model_before():
|
||||
assert backend.op_count(op, before=True) == 4
|
||||
|
||||
# In post-nodes, reduce scatter and all gather should be there,
|
||||
# all reduce should not
|
||||
backend_no_func.check_after_ops(model.ops_in_model_after())
|
||||
for op in model.ops_in_model_after():
|
||||
assert backend.op_count(op, before=False) == 4
|
||||
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model():
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
for node in backend_func.graph_post_pass.nodes:
|
||||
for op in model.ops_in_model():
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
assert all(found[op] for op in model.ops_in_model())
|
||||
find_auto_fn(backend.graph_post_pass.nodes, op)
|
||||
|
||||
437
tests/distributed/test_multiproc_executor.py
Normal file
437
tests/distributed/test_multiproc_executor.py
Normal file
@ -0,0 +1,437 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Integration tests for MultiprocExecutor at the executor level.
|
||||
This test directly tests the executor without going through the LLM interface,
|
||||
focusing on executor initialization, RPC calls, and distributed execution.
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils import get_open_port
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
|
||||
|
||||
MODEL = "facebook/opt-125m"
|
||||
|
||||
|
||||
def create_vllm_config(
|
||||
tensor_parallel_size: int = 1,
|
||||
pipeline_parallel_size: int = 1,
|
||||
max_model_len: int = 256,
|
||||
gpu_memory_utilization: float = 0.3,
|
||||
distributed_executor_backend: str = "mp",
|
||||
nnodes: int = 1,
|
||||
node_rank: int = 0,
|
||||
master_port: int = 0,
|
||||
) -> VllmConfig:
|
||||
"""Create a VllmConfig for testing using EngineArgs."""
|
||||
engine_args = EngineArgs(
|
||||
model=MODEL,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
pipeline_parallel_size=pipeline_parallel_size,
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True,
|
||||
)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
|
||||
# Override distributed node settings if needed
|
||||
if nnodes > 1 or node_rank > 0:
|
||||
vllm_config.parallel_config.nnodes = nnodes
|
||||
vllm_config.parallel_config.node_rank = node_rank
|
||||
vllm_config.parallel_config.master_port = master_port
|
||||
if nnodes > 1:
|
||||
vllm_config.parallel_config.disable_custom_all_reduce = True
|
||||
|
||||
return vllm_config
|
||||
|
||||
|
||||
def create_test_scheduler_output(num_requests: int = 1) -> SchedulerOutput:
|
||||
"""Create a minimal SchedulerOutput for testing."""
|
||||
# This is a simplified version - in practice you'd need proper
|
||||
# SchedulerOutput construction based on the actual vLLM v1 API
|
||||
return SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_resumed_reqs=[],
|
||||
scheduled_running_reqs=[],
|
||||
num_scheduled_tokens={},
|
||||
total_num_scheduled_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
def test_multiproc_executor_initialization():
|
||||
"""Test that MultiprocExecutor can be initialized with proper config."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=1,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
# Create executor - this should initialize workers
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
# Verify executor properties
|
||||
assert executor.world_size == 1, "World size should be 1 for single GPU"
|
||||
assert executor.local_world_size == 1, "Local world size should be 1"
|
||||
assert hasattr(executor, "workers"), "Executor should have workers"
|
||||
assert len(executor.workers) == 1, "Should have 1 worker for single GPU"
|
||||
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_multiproc_executor_initialization_tensor_parallel():
|
||||
"""Test MultiprocExecutor initialization with tensor parallelism."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=2,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
# Create executor
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
# Verify executor properties
|
||||
assert executor.world_size == 2, "World size should be 2 for TP=2"
|
||||
assert executor.local_world_size == 2, "Local world size should be 2"
|
||||
assert len(executor.workers) == 2, "Should have 2 workers for TP=2"
|
||||
|
||||
# Verify output rank calculation
|
||||
output_rank = executor._get_output_rank()
|
||||
assert output_rank == 0, "Output rank should be 0 for TP=2, PP=1"
|
||||
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_multiproc_executor_collective_rpc():
|
||||
"""Test collective RPC calls to all workers."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=2,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
# Create executor
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
try:
|
||||
# Test check_health RPC - should work without errors
|
||||
executor.check_health()
|
||||
|
||||
# Test that RPC works correctly
|
||||
# Note: We're just testing that the RPC mechanism works,
|
||||
# not testing actual model execution here
|
||||
assert not executor.is_failed, "Executor should not be in failed state"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
def test_multiproc_executor_failure_callback():
|
||||
"""Test failure callback registration and invocation."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=1,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
try:
|
||||
# Test callback registration
|
||||
callback_invoked = []
|
||||
|
||||
def test_callback():
|
||||
callback_invoked.append(True)
|
||||
|
||||
# Register callback
|
||||
executor.register_failure_callback(test_callback)
|
||||
|
||||
# Callback should not be invoked yet
|
||||
assert len(callback_invoked) == 0, "Callback should not be invoked immediately"
|
||||
|
||||
# Simulate failure
|
||||
executor.is_failed = True
|
||||
|
||||
# Register another callback - should be invoked immediately
|
||||
executor.register_failure_callback(test_callback)
|
||||
assert len(callback_invoked) == 1, (
|
||||
"Callback should be invoked when executor is failed"
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_multiproc_executor_worker_monitor():
|
||||
"""Test that worker monitor is set up correctly."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=2,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
try:
|
||||
# Verify all worker processes are alive
|
||||
for worker in executor.workers:
|
||||
assert worker.proc.is_alive(), f"Worker rank {worker.rank} should be alive"
|
||||
|
||||
# Verify executor is not in failed state
|
||||
assert not executor.is_failed, "Executor should not be in failed state"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
# After shutdown, workers should be terminated
|
||||
import time
|
||||
|
||||
time.sleep(0.5) # Give processes time to terminate
|
||||
for worker in executor.workers:
|
||||
assert not worker.proc.is_alive(), (
|
||||
f"Worker rank {worker.rank} should terminate after shutdown"
|
||||
)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_multiproc_executor_get_response_message_queues():
|
||||
"""Test message queue retrieval for different ranks."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=2,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
try:
|
||||
# Get all message queues
|
||||
all_queues = executor.get_response_mqs()
|
||||
assert len(all_queues) == 2, "Should have 2 message queues for 2 workers"
|
||||
|
||||
# Get message queue for specific rank
|
||||
rank0_queue = executor.get_response_mqs(unique_reply_rank=0)
|
||||
assert len(rank0_queue) == 1, "Should have 1 message queue for rank 0"
|
||||
|
||||
rank1_queue = executor.get_response_mqs(unique_reply_rank=1)
|
||||
assert len(rank1_queue) == 1, "Should have 1 message queue for rank 1"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
def test_multiproc_executor_shutdown_cleanup():
|
||||
"""Test that shutdown properly cleans up resources."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=1,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
# Verify executor is set up
|
||||
assert hasattr(executor, "workers"), "Executor should have workers"
|
||||
assert len(executor.workers) > 0, "Should have at least one worker"
|
||||
|
||||
# Shutdown
|
||||
executor.shutdown()
|
||||
|
||||
# Verify cleanup
|
||||
import time
|
||||
|
||||
time.sleep(0.5) # Give processes time to terminate
|
||||
|
||||
for worker in executor.workers:
|
||||
assert not worker.proc.is_alive(), "Worker processes should be terminated"
|
||||
|
||||
# Verify shutdown event is set
|
||||
assert executor.shutdown_event.is_set(), "Shutdown event should be set"
|
||||
|
||||
# Multiple shutdowns should be safe (idempotent)
|
||||
executor.shutdown()
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
def test_multiproc_executor_pipeline_parallel():
|
||||
"""Test MultiprocExecutor with pipeline parallelism."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=2,
|
||||
pipeline_parallel_size=2,
|
||||
)
|
||||
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
try:
|
||||
# Verify executor properties
|
||||
assert executor.world_size == 4, "World size should be 4 for TP=2, PP=2"
|
||||
assert len(executor.workers) == 4, "Should have 4 workers"
|
||||
|
||||
# Verify output rank calculation
|
||||
# For TP=2, PP=2: output should be from the last PP stage (ranks 2-3)
|
||||
# Specifically rank 2 (first rank of last PP stage)
|
||||
output_rank = executor._get_output_rank()
|
||||
assert output_rank == 2, "Output rank should be 2 (first rank of last PP stage)"
|
||||
|
||||
# Verify max_concurrent_batches for pipeline parallel
|
||||
assert executor.max_concurrent_batches == 2, (
|
||||
"Max concurrent batches should equal PP size"
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
def test_multiproc_executor_properties():
|
||||
"""Test various executor properties and configurations."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=1,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
try:
|
||||
# Test supports_pp property
|
||||
assert MultiprocExecutor.supports_pp is True, (
|
||||
"MultiprocExecutor should support pipeline parallelism"
|
||||
)
|
||||
|
||||
# Test world_size calculation
|
||||
assert executor.world_size == (
|
||||
executor.parallel_config.tensor_parallel_size
|
||||
* executor.parallel_config.pipeline_parallel_size
|
||||
), "World size should equal TP * PP"
|
||||
|
||||
# Test local_world_size calculation
|
||||
assert executor.local_world_size == (
|
||||
executor.parallel_config.world_size // executor.parallel_config.nnodes
|
||||
), "Local world size should be world_size / nnodes"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
def test_multiproc_executor_multi_node():
|
||||
"""
|
||||
Test MultiprocExecutor with multi-node configuration.
|
||||
This simulates 2 nodes with TP=4:
|
||||
- Node 0 (rank 0): Uses GPUs 0,1 (CUDA_VISIBLE_DEVICES=0,1) with TP=2
|
||||
- Node 1 (rank 1): Uses GPUs 2,3 (CUDA_VISIBLE_DEVICES=2,3) with TP=2
|
||||
Total world_size = 4, nnodes = 2
|
||||
"""
|
||||
port = get_open_port()
|
||||
# symm_mem does not work for simulating multi instance in single node
|
||||
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
|
||||
|
||||
def run_node(node_rank: int, result_queue: multiprocessing.Queue, port: int):
|
||||
"""Run a single node's executor."""
|
||||
executor = None
|
||||
try:
|
||||
# Set CUDA_VISIBLE_DEVICES for this node
|
||||
if node_rank == 0:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
||||
else:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
|
||||
|
||||
# Create config for this node
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=4, # Total TP across all nodes
|
||||
pipeline_parallel_size=1,
|
||||
nnodes=2, # 2 nodes
|
||||
node_rank=node_rank,
|
||||
master_port=port, # same port
|
||||
)
|
||||
|
||||
# Create executor for this node
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
# Verify node-specific properties
|
||||
assert executor.world_size == 4, (
|
||||
f"World size should be 4 on node {node_rank}"
|
||||
)
|
||||
assert executor.local_world_size == 2, (
|
||||
f"Local world size should be 2 on node {node_rank}"
|
||||
)
|
||||
assert len(executor.workers) == 2, (
|
||||
f"Should have 2 local workers on node {node_rank}"
|
||||
)
|
||||
|
||||
# Verify worker ranks are correct for this node
|
||||
expected_ranks = [node_rank * 2, node_rank * 2 + 1]
|
||||
actual_ranks = sorted([w.rank for w in executor.workers])
|
||||
assert actual_ranks == expected_ranks, (
|
||||
f"Node {node_rank} should have workers "
|
||||
f"with ranks {expected_ranks}, got {actual_ranks}"
|
||||
)
|
||||
# Verify all workers are alive
|
||||
for worker in executor.workers:
|
||||
assert worker.proc.is_alive(), (
|
||||
f"Worker rank {worker.rank} should be alive on node {node_rank}"
|
||||
)
|
||||
# executor.gen
|
||||
# Put success result in queue BEFORE shutdown to avoid hanging
|
||||
result_queue.put({"node": node_rank, "success": True})
|
||||
import time
|
||||
|
||||
time.sleep(2)
|
||||
executor.shutdown()
|
||||
except Exception as e:
|
||||
# Put failure result in queue
|
||||
result_queue.put({"node": node_rank, "success": False, "error": str(e)})
|
||||
raise e
|
||||
finally:
|
||||
if executor is not None:
|
||||
executor.shutdown()
|
||||
|
||||
# Create a queue to collect results from both processes
|
||||
result_queue: multiprocessing.Queue[dict[str, int | bool]] = multiprocessing.Queue()
|
||||
|
||||
# Start both node processes
|
||||
processes = []
|
||||
for node_rank in range(2):
|
||||
p = multiprocessing.Process(
|
||||
target=run_node,
|
||||
args=(node_rank, result_queue, port),
|
||||
name=f"Node{node_rank}",
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
# Wait for both processes to complete
|
||||
all_completed = True
|
||||
for p in processes:
|
||||
p.join(timeout=60)
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
p.join(timeout=20)
|
||||
if p.is_alive():
|
||||
p.kill()
|
||||
p.join()
|
||||
all_completed = False
|
||||
|
||||
# Check results from both nodes
|
||||
results: list[dict[str, int | bool]] = []
|
||||
while len(results) < 2:
|
||||
try:
|
||||
result = result_queue.get(timeout=1)
|
||||
results.append(result)
|
||||
except Exception:
|
||||
pass
|
||||
assert all_completed, "Not all processes completed successfully"
|
||||
assert len(results) == 2, f"Expected 2 results, got {len(results)}"
|
||||
assert results[0]["success"], f"Node 0 failed: {results[0]}"
|
||||
assert results[1]["success"], f"Node 1 failed: {results[1]}"
|
||||
@ -18,6 +18,7 @@ import pytest
|
||||
from vllm.config.compilation import CompilationMode
|
||||
from vllm.config.model import RunnerOption
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from ..models.registry import HF_EXAMPLE_MODELS
|
||||
@ -161,6 +162,7 @@ def _compare_sp(
|
||||
test_options: SPTestOptions,
|
||||
num_gpus_available: int,
|
||||
use_inductor_graph_partition: bool,
|
||||
enable_async_tp: bool,
|
||||
*,
|
||||
method: Literal["generate", "encode"],
|
||||
is_multimodal: bool,
|
||||
@ -244,10 +246,10 @@ def _compare_sp(
|
||||
|
||||
compilation_config = {
|
||||
"mode": CompilationMode.VLLM_COMPILE,
|
||||
"custom_ops": ["+rms_norm"],
|
||||
"compile_sizes": [4, 8],
|
||||
"pass_config": {
|
||||
"enable_sequence_parallelism": True,
|
||||
"enable_async_tp": enable_async_tp,
|
||||
"enable_fusion": enable_fusion,
|
||||
"enable_noop": True,
|
||||
},
|
||||
@ -307,6 +309,7 @@ SP_TEST_MODELS = [
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
|
||||
@pytest.mark.parametrize("enable_async_tp", [False]) # TODO: enable async TP
|
||||
@create_new_process_for_each_test()
|
||||
def test_tp_sp_generation(
|
||||
model_id: str,
|
||||
@ -316,10 +319,19 @@ def test_tp_sp_generation(
|
||||
test_options: SPTestOptions,
|
||||
num_gpus_available,
|
||||
use_inductor_graph_partition: bool,
|
||||
enable_async_tp: bool,
|
||||
):
|
||||
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
# Skip FP8 SP-only test on sm89 (compute capability 8.9)
|
||||
if (
|
||||
"fp8" in model_id.lower()
|
||||
and current_platform.get_device_capability() < (9, 0)
|
||||
and (not enable_async_tp)
|
||||
):
|
||||
pytest.skip("FP8 reduction support begins with sm90 capable devices.")
|
||||
|
||||
_compare_sp(
|
||||
model_id,
|
||||
parallel_setup,
|
||||
@ -328,6 +340,7 @@ def test_tp_sp_generation(
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
use_inductor_graph_partition,
|
||||
enable_async_tp=enable_async_tp,
|
||||
method="generate",
|
||||
is_multimodal=False,
|
||||
)
|
||||
|
||||
@ -17,7 +17,7 @@ def chat_server_with_force_include_usage(request): # noqa: F811
|
||||
"128",
|
||||
"--enforce-eager",
|
||||
"--max-num-seqs",
|
||||
"1",
|
||||
"4",
|
||||
"--enable-force-include-usage",
|
||||
"--port",
|
||||
"55857",
|
||||
@ -78,7 +78,7 @@ def transcription_server_with_force_include_usage():
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-num-seqs",
|
||||
"1",
|
||||
"4",
|
||||
"--enforce-eager",
|
||||
"--enable-force-include-usage",
|
||||
"--gpu-memory-utilization",
|
||||
|
||||
@ -16,6 +16,7 @@ from transformers import AutoTokenizer
|
||||
|
||||
from vllm import version
|
||||
|
||||
from ...conftest import LocalAssetServer
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODELS = {
|
||||
@ -69,7 +70,6 @@ async def client(server):
|
||||
|
||||
|
||||
_PROMPT = "Hello my name is Robert and I love magic"
|
||||
_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
|
||||
|
||||
def _get_expected_values(num_requests: int, prompt_ids: list[int], max_tokens: int):
|
||||
@ -250,6 +250,7 @@ HIDDEN_DEPRECATED_METRICS: list[str] = [
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_exist(
|
||||
local_asset_server: LocalAssetServer,
|
||||
server: RemoteOpenAIServer,
|
||||
client: openai.AsyncClient,
|
||||
model_key: str,
|
||||
@ -265,13 +266,21 @@ async def test_metrics_exist(
|
||||
temperature=0.0,
|
||||
)
|
||||
else:
|
||||
# https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg
|
||||
await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": _IMAGE_URL}},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": local_asset_server.url_for(
|
||||
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
),
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
],
|
||||
}
|
||||
|
||||
@ -17,10 +17,10 @@ MAXIMUM_IMAGES = 2
|
||||
|
||||
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
|
||||
TEST_IMAGE_ASSETS = [
|
||||
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
"Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
|
||||
"1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
|
||||
"RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
|
||||
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
"Grayscale_8bits_palette_sample_image.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/Grayscale_8bits_palette_sample_image.png",
|
||||
"1280px-Venn_diagram_rgb.svg.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/1280px-Venn_diagram_rgb.svg.png",
|
||||
"RGBA_comp.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/RGBA_comp.png",
|
||||
]
|
||||
|
||||
EXPECTED_MM_BEAM_SEARCH_RES = [
|
||||
|
||||
@ -19,10 +19,10 @@ assert vlm2vec_jinja_path.exists()
|
||||
|
||||
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
|
||||
TEST_IMAGE_ASSETS = [
|
||||
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
"Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
|
||||
"1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
|
||||
"RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
|
||||
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
"Grayscale_8bits_palette_sample_image.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/Grayscale_8bits_palette_sample_image.png",
|
||||
"1280px-Venn_diagram_rgb.svg.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/1280px-Venn_diagram_rgb.svg.png",
|
||||
"RGBA_comp.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/RGBA_comp.png",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -150,8 +150,8 @@ def test_merge_attn_states(
|
||||
output_torch = output.clone()
|
||||
output_lse_torch = output_lse.clone()
|
||||
total_time_torch_kernel = 0
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start = torch.Event(enable_timing=True)
|
||||
end = torch.Event(enable_timing=True)
|
||||
|
||||
# 0. Run the Torch kernel
|
||||
prefix_lse_torch = prefix_lse.clone()
|
||||
@ -188,8 +188,8 @@ def test_merge_attn_states(
|
||||
output_lse_ref_triton = output_lse.clone()
|
||||
|
||||
total_time_triton_kernel = 0
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start = torch.Event(enable_timing=True)
|
||||
end = torch.Event(enable_timing=True)
|
||||
|
||||
for _ in range(warmup_times):
|
||||
merge_attn_states_triton(
|
||||
|
||||
@ -40,8 +40,6 @@ NUM_EXPERTS = [8, 64]
|
||||
TOP_KS = [1, 2, 6]
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -33,8 +33,6 @@ if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
# Test configurations
|
||||
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (7, 0):
|
||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
DTYPES = [torch.bfloat16]
|
||||
|
||||
|
||||
@ -42,8 +42,6 @@ MNK_FACTORS = [
|
||||
]
|
||||
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
||||
@ -7,6 +7,7 @@ fp8 block-quantized case.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
import torch.distributed
|
||||
@ -14,6 +15,7 @@ from torch.distributed import ProcessGroup
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
@ -61,6 +63,23 @@ requires_deep_gemm = pytest.mark.skipif(
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def with_dp_metadata(M: int, world_size: int):
|
||||
num_tokens_across_dp = torch.tensor([M] * world_size, device="cpu", dtype=torch.int)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.parallel_config.data_parallel_size = world_size
|
||||
vllm_config.parallel_config.enable_expert_parallel = True
|
||||
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config,
|
||||
num_tokens=M,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def next_power_of_2(x):
|
||||
import math
|
||||
|
||||
@ -285,18 +304,21 @@ def deepep_deepgemm_moe_impl(
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
out = mk.forward(
|
||||
hidden_states=test_tensors.rank_tokens,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
topk_ids=test_tensors.topk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
with with_dp_metadata(
|
||||
M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
|
||||
):
|
||||
out = mk.forward(
|
||||
hidden_states=test_tensors.rank_tokens,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
topk_ids=test_tensors.topk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@ -45,8 +45,6 @@ MNK_FACTORS = [
|
||||
]
|
||||
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
def quant_fp8_per_tensor_batches(a):
|
||||
@ -79,10 +77,14 @@ class TestData:
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors_8bit(
|
||||
m: int, k: int, n: int, e: int, reorder: bool
|
||||
m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu"
|
||||
) -> "TestData":
|
||||
is_gated = activation != "relu2_no_mul"
|
||||
|
||||
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||
w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
|
||||
w13 = torch.randn(
|
||||
(e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
|
||||
)
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Scale to fp8
|
||||
@ -192,18 +194,22 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
|
||||
def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
activation: str,
|
||||
monkeypatch,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)
|
||||
td = TestData.make_moe_tensors_8bit(
|
||||
m, k, n, e, reorder=False, activation=activation
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
@ -235,7 +241,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
activation=activation,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=True,
|
||||
@ -255,7 +261,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
td.layer,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation="silu",
|
||||
activation=activation,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=True,
|
||||
|
||||
@ -81,8 +81,6 @@ FUSED_MOE_WN16_MNK_FACTORS = [
|
||||
]
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
def run_moe_test(
|
||||
|
||||
@ -192,8 +192,6 @@ def pplx_cutlass_moe(
|
||||
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
def _pplx_moe(
|
||||
|
||||
@ -81,8 +81,6 @@ TOP_KS = [1, 2, 6]
|
||||
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
def torch_prepare(
|
||||
|
||||
@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
|
||||
|
||||
@ -29,8 +29,6 @@ if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
# Test configurations
|
||||
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (7, 0):
|
||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
M = [1, 33, 64, 222]
|
||||
|
||||
169
tests/model_executor/test_eagle_quantization.py
Normal file
169
tests/model_executor/test_eagle_quantization.py
Normal file
@ -0,0 +1,169 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig, SpeculativeConfig, VllmConfig
|
||||
from vllm.model_executor.models.utils import get_draft_quant_config
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DEVICES = (
|
||||
[f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
if current_platform.is_cuda_alike()
|
||||
else ["cpu"]
|
||||
)
|
||||
|
||||
|
||||
def test_get_draft_quant_config_with_draft_model():
|
||||
mock_draft_model_config = Mock(spec=ModelConfig)
|
||||
mock_load_config = Mock(spec=LoadConfig)
|
||||
mock_speculative_config = Mock(spec=SpeculativeConfig)
|
||||
mock_speculative_config.draft_model_config = mock_draft_model_config
|
||||
|
||||
mock_vllm_config = Mock(spec=VllmConfig)
|
||||
mock_vllm_config.speculative_config = mock_speculative_config
|
||||
mock_vllm_config.load_config = mock_load_config
|
||||
|
||||
mock_quant_config = Mock()
|
||||
with patch.object(
|
||||
VllmConfig, "get_quantization_config", return_value=mock_quant_config
|
||||
):
|
||||
result = get_draft_quant_config(mock_vllm_config)
|
||||
|
||||
# Verify the function calls get_quantization_config with draft model config
|
||||
VllmConfig.get_quantization_config.assert_called_once_with(
|
||||
mock_draft_model_config, mock_load_config
|
||||
)
|
||||
assert result == mock_quant_config
|
||||
|
||||
|
||||
def test_get_draft_quant_config_without_draft_model():
|
||||
mock_speculative_config = Mock(spec=SpeculativeConfig)
|
||||
mock_speculative_config.draft_model_config = None
|
||||
|
||||
mock_vllm_config = Mock(spec=VllmConfig)
|
||||
mock_vllm_config.speculative_config = mock_speculative_config
|
||||
mock_vllm_config.load_config = Mock(spec=LoadConfig)
|
||||
|
||||
result = get_draft_quant_config(mock_vllm_config)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_fc_layer_quant_config_usage(dist_init, device) -> None:
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
torch.set_default_device(device)
|
||||
|
||||
input_size = 256
|
||||
output_size = 128
|
||||
|
||||
fc_no_quant = ReplicatedLinear(
|
||||
input_size=input_size,
|
||||
output_size=output_size,
|
||||
bias=False,
|
||||
params_dtype=torch.float16,
|
||||
quant_config=None,
|
||||
prefix="fc",
|
||||
)
|
||||
|
||||
assert fc_no_quant.quant_config is None
|
||||
assert fc_no_quant.input_size == input_size
|
||||
assert fc_no_quant.output_size == output_size
|
||||
|
||||
mock_quant_config = Mock()
|
||||
fc_with_quant = ReplicatedLinear(
|
||||
input_size=input_size,
|
||||
output_size=output_size,
|
||||
bias=False,
|
||||
params_dtype=torch.float16,
|
||||
quant_config=mock_quant_config,
|
||||
prefix="fc",
|
||||
)
|
||||
|
||||
assert fc_with_quant.quant_config == mock_quant_config
|
||||
|
||||
# Check forward pass
|
||||
x = torch.randn(2, input_size, dtype=torch.float16)
|
||||
output, _ = fc_no_quant(x)
|
||||
assert output.shape == (2, output_size)
|
||||
|
||||
|
||||
def test_kv_cache_scale_name_handling():
|
||||
# Mock a quant config that supports cache scales
|
||||
mock_quant_config = Mock()
|
||||
mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale")
|
||||
|
||||
# Condition check in load_weights
|
||||
name = "layers.0.self_attn.k_proj.weight"
|
||||
scale_name = mock_quant_config.get_cache_scale(name)
|
||||
|
||||
# Check if get_cache_scale is called and returns expected value
|
||||
mock_quant_config.get_cache_scale.assert_called_once_with(name)
|
||||
assert scale_name == "layers.0.self_attn.kv_scale"
|
||||
|
||||
|
||||
def test_kv_cache_scale_name_no_scale():
|
||||
# Mock a quant config that returns None for get_cache_scale
|
||||
mock_quant_config = Mock()
|
||||
mock_quant_config.get_cache_scale = Mock(return_value=None)
|
||||
|
||||
name = "layers.0.mlp.gate_proj.weight"
|
||||
scale_name = mock_quant_config.get_cache_scale(name)
|
||||
|
||||
# Should return None for weights that don't have cache scales
|
||||
assert scale_name is None
|
||||
|
||||
|
||||
def test_maybe_remap_kv_scale_name():
|
||||
from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
|
||||
|
||||
params_dict = {
|
||||
"layers.0.self_attn.kv_scale": Mock(),
|
||||
"layers.1.self_attn.kv_scale": Mock(),
|
||||
}
|
||||
|
||||
name = "layers.0.self_attn.some_scale"
|
||||
remapped = maybe_remap_kv_scale_name(name, params_dict)
|
||||
|
||||
assert remapped in params_dict or remapped == name or remapped is None
|
||||
|
||||
|
||||
def test_load_weights_kv_scale_handling():
|
||||
kv_scale_param = Mock()
|
||||
kv_scale_param.weight_loader = Mock()
|
||||
|
||||
params_dict = {
|
||||
"layers.0.self_attn.kv_scale": kv_scale_param,
|
||||
}
|
||||
|
||||
mock_quant_config = Mock()
|
||||
mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale")
|
||||
|
||||
# Load_weights logic for KV cache scales
|
||||
name = "layers.0.self_attn.k_proj.weight"
|
||||
loaded_weight_tensor = torch.tensor([1.0, 2.0])
|
||||
|
||||
if mock_quant_config is not None:
|
||||
scale_name = mock_quant_config.get_cache_scale(name)
|
||||
if scale_name:
|
||||
param = params_dict[scale_name]
|
||||
assert param is kv_scale_param
|
||||
weight_to_load = (
|
||||
loaded_weight_tensor
|
||||
if loaded_weight_tensor.dim() == 0
|
||||
else loaded_weight_tensor[0]
|
||||
)
|
||||
|
||||
assert scale_name == "layers.0.self_attn.kv_scale"
|
||||
assert weight_to_load == loaded_weight_tensor[0]
|
||||
@ -11,7 +11,7 @@ from vllm import TokensPrompt
|
||||
["Qwen/Qwen3-0.6B"],
|
||||
)
|
||||
@torch.inference_mode
|
||||
def test_embed_models(hf_runner, vllm_runner, model: str):
|
||||
def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
|
||||
n_prompt_tokens = [55, 56, 57]
|
||||
token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
|
||||
|
||||
@ -21,7 +21,7 @@ def test_embed_models(hf_runner, vllm_runner, model: str):
|
||||
enforce_eager=True,
|
||||
runner="pooling",
|
||||
enable_chunked_prefill=False,
|
||||
enable_prefix_caching=False,
|
||||
enable_prefix_caching=True,
|
||||
) as vllm_model:
|
||||
pooling_outputs = vllm_model.llm.encode(
|
||||
[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
|
||||
@ -30,4 +30,29 @@ def test_embed_models(hf_runner, vllm_runner, model: str):
|
||||
|
||||
for n, output in zip(n_prompt_tokens, pooling_outputs):
|
||||
assert len(output.prompt_token_ids) == n
|
||||
assert len(output.outputs.data) == n
|
||||
assert output.num_cached_tokens == 0
|
||||
|
||||
# test enable_prefix_caching plus all pooling
|
||||
# we need to skip reading cache at this request by
|
||||
# request.skip_reading_prefix_cache
|
||||
pooling_outputs = vllm_model.llm.encode(
|
||||
[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
|
||||
pooling_task="token_embed",
|
||||
)
|
||||
|
||||
for n, output in zip(n_prompt_tokens, pooling_outputs):
|
||||
assert len(output.prompt_token_ids) == n
|
||||
assert len(output.outputs.data) == n
|
||||
assert output.num_cached_tokens == 0
|
||||
|
||||
# skip_reading_prefix_cache can still write to cache
|
||||
# to accelerate following requests
|
||||
pooling_outputs = vllm_model.llm.encode(
|
||||
[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
|
||||
pooling_task="embed",
|
||||
)
|
||||
|
||||
for n, output in zip(n_prompt_tokens, pooling_outputs):
|
||||
assert len(output.prompt_token_ids) == n
|
||||
assert output.num_cached_tokens > 0
|
||||
|
||||
@ -75,7 +75,7 @@ def test_gemma_multimodal(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg"
|
||||
"url": "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/red_chair.jpg"
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "A fine 19th century piece of furniture."},
|
||||
|
||||
115
tests/models/multimodal/generation/test_multimodal_gguf.py
Normal file
115
tests/models/multimodal/generation/test_multimodal_gguf.py
Normal file
@ -0,0 +1,115 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Literal, NamedTuple
|
||||
|
||||
import pytest
|
||||
from huggingface_hub import hf_hub_download
|
||||
from pytest import MarkDecorator
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
|
||||
from ....conftest import PromptImageInput, VllmRunner
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
|
||||
class GGUFMMTestConfig(NamedTuple):
|
||||
original_model: str
|
||||
gguf_repo: str
|
||||
gguf_backbone: str
|
||||
gguf_mmproj: str
|
||||
prompt: list[str]
|
||||
mm_data: dict[Literal["images"], PromptImageInput]
|
||||
max_model_len: int = 4096
|
||||
marks: list[MarkDecorator] = []
|
||||
|
||||
@property
|
||||
def gguf_model(self):
|
||||
hf_hub_download(self.gguf_repo, filename=self.gguf_mmproj)
|
||||
return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)
|
||||
|
||||
|
||||
GEMMA3_CONFIG = GGUFMMTestConfig(
|
||||
original_model="google/gemma-3-4b-it",
|
||||
gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
|
||||
gguf_backbone="gemma-3-4b-it-q4_0.gguf",
|
||||
gguf_mmproj="mmproj-model-f16-4B.gguf",
|
||||
prompt=["<start_of_image>Describe this image in detail:"],
|
||||
mm_data={"images": [ImageAsset("stop_sign").pil_image]},
|
||||
marks=[pytest.mark.core_model],
|
||||
)
|
||||
|
||||
MODELS_TO_TEST = [GEMMA3_CONFIG]
|
||||
|
||||
|
||||
def run_multimodal_gguf_test(
|
||||
vllm_runner: type[VllmRunner],
|
||||
model: GGUFMMTestConfig,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
):
|
||||
# Run gguf model.
|
||||
with (
|
||||
set_default_torch_num_threads(1),
|
||||
vllm_runner(
|
||||
model_name=model.gguf_model,
|
||||
enforce_eager=True,
|
||||
tokenizer_name=model.original_model,
|
||||
dtype=dtype,
|
||||
max_model_len=model.max_model_len,
|
||||
) as gguf_model,
|
||||
):
|
||||
gguf_outputs = gguf_model.generate_greedy_logprobs(
|
||||
prompts=model.prompt,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
**model.mm_data,
|
||||
)
|
||||
|
||||
# Run unquantized model.
|
||||
with vllm_runner(
|
||||
model_name=model.original_model,
|
||||
enforce_eager=True, # faster tests
|
||||
dtype=dtype,
|
||||
max_model_len=model.max_model_len,
|
||||
) as original_model:
|
||||
original_outputs = original_model.generate_greedy_logprobs(
|
||||
prompts=model.prompt,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
**model.mm_data,
|
||||
)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=original_outputs,
|
||||
outputs_1_lst=gguf_outputs,
|
||||
name_0="original",
|
||||
name_1="gguf",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gguf"),
|
||||
reason="gguf is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
pytest.param(test_config, marks=test_config.marks)
|
||||
for test_config in MODELS_TO_TEST
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_models(
|
||||
vllm_runner: type[VllmRunner],
|
||||
model: GGUFMMTestConfig,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
run_multimodal_gguf_test(vllm_runner, model, dtype, max_tokens, num_logprobs)
|
||||
@ -14,10 +14,13 @@ from vllm.platforms import current_platform
|
||||
from ...utils import compare_two_settings, multi_gpu_test
|
||||
from ..utils import check_embeddings_close, check_logprobs_close
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
current_platform.is_rocm(),
|
||||
reason="bitsandbytes quantization not supported on ROCm (CUDA-only kernels)",
|
||||
)
|
||||
if current_platform.is_rocm():
|
||||
from vllm.platforms.rocm import on_gfx9
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
on_gfx9(),
|
||||
reason="bitsandbytes not supported on gfx9 (warp size 64 limitation)",
|
||||
)
|
||||
|
||||
models_4bit_to_test = [
|
||||
("facebook/opt-125m", "quantize opt model inflight"),
|
||||
|
||||
@ -78,6 +78,12 @@ DOLPHIN_CONFIG = GGUFTestConfig(
|
||||
gguf_filename="tinydolphin-2.8-1.1b.Q6_K.gguf",
|
||||
)
|
||||
|
||||
GEMMA3_CONFIG = GGUFTestConfig(
|
||||
original_model="google/gemma-3-270m-it",
|
||||
gguf_repo="ggml-org/gemma-3-270m-it-qat-GGUF",
|
||||
gguf_filename="gemma-3-270m-it-qat-Q4_0.gguf",
|
||||
)
|
||||
|
||||
MODELS = [
|
||||
# LLAMA_CONFIG, # broken: https://github.com/vllm-project/vllm/issues/19458
|
||||
QWEN2_CONFIG,
|
||||
@ -85,6 +91,7 @@ MODELS = [
|
||||
GPT2_CONFIG,
|
||||
STABLELM_CONFIG,
|
||||
DOLPHIN_CONFIG,
|
||||
GEMMA3_CONFIG,
|
||||
# STARCODER_CONFIG, # broken
|
||||
]
|
||||
|
||||
@ -148,7 +155,7 @@ def check_model_outputs(
|
||||
"model",
|
||||
[pytest.param(test_config, marks=test_config.marks) for test_config in MODELS],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("tp_size", [1])
|
||||
|
||||
@ -173,6 +173,10 @@ class _HfExamplesInfo:
|
||||
|
||||
_TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
# [Decoder-only]
|
||||
"AfmoeForCausalLM": _HfExamplesInfo(
|
||||
"arcee-ai/Trinity-Nano",
|
||||
is_available_online=False,
|
||||
),
|
||||
"ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-Instruct-2509"),
|
||||
"AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True),
|
||||
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True),
|
||||
|
||||
@ -16,10 +16,10 @@ from vllm.multimodal.utils import MediaConnector, argsort_mm_positions
|
||||
|
||||
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
|
||||
TEST_IMAGE_ASSETS = [
|
||||
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
"Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
|
||||
"1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
|
||||
"RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
|
||||
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
"Grayscale_8bits_palette_sample_image.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/Grayscale_8bits_palette_sample_image.png",
|
||||
"1280px-Venn_diagram_rgb.svg.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/1280px-Venn_diagram_rgb.svg.png",
|
||||
"RGBA_comp.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/RGBA_comp.png",
|
||||
]
|
||||
|
||||
TEST_VIDEO_URLS = [
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user