Merge branch 'main' into wye-refactor-quant-folder

This commit is contained in:
Wentao Ye 2025-08-15 11:24:28 -04:00 committed by GitHub
commit 7e2fb3c507
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
161 changed files with 4947 additions and 2499 deletions

View File

@ -31,16 +31,6 @@
steps:
##### fast check tests #####
- label: Documentation Build # 2min
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/test_docs"
fast_check: true
no_gpu: True
commands:
- pip install -r ../requirements/docs.txt
# TODO: add `--strict` once warnings in docstrings are fixed
- mkdocs build
- label: Pytorch Nightly Dependency Override Check # 2min
# if this test fails, it means the nightly torch version is not compatible with some
# of the dependencies. Please check the error message and add the package to whitelist
@ -669,6 +659,7 @@ steps:
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
# Fusion
- pytest -v -s tests/compile/test_fusion_all_reduce.py

3
.gitignore vendored
View File

@ -207,3 +207,6 @@ shellcheck*/
# Ignore moe/marlin_moe gen code
csrc/moe/marlin_moe_wna16/kernel_*
# Ignore ep_kernels_workspace folder
ep_kernels_workspace/

View File

@ -249,7 +249,6 @@ set(VLLM_EXT_SRC
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/custom_all_reduce.cu"
"csrc/torch_bindings.cpp")
@ -352,6 +351,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
@ -365,7 +368,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_SRCS}"
CUDA_ARCHS "${MARLIN_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu"
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
else()
message(STATUS "Not building Marlin kernels as no compatible archs found"
@ -855,6 +863,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set_gencode_flags_for_srcs(
SRCS "${MOE_WNAA16_MARLIN_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MOE_WNAA16_MARLIN_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})

View File

@ -1,63 +1,199 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import asyncio
import logging
import os
import aiohttp
from quart import Quart, make_response, request
from quart import Quart, Response, make_response, request
from rate_limiter import RateLimiter
from request_queue import RequestQueue
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
app = Quart(__name__)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def forward_request(url, data):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
def parse_args():
"""parse command line arguments"""
parser = argparse.ArgumentParser(description="vLLM P/D disaggregation proxy server")
# Add args
parser.add_argument(
"--timeout",
type=float,
default=300,
help="Timeout for backend service requests in seconds (default: 300)",
)
parser.add_argument(
"--max-concurrent",
type=int,
default=100,
help="Maximum concurrent requests to backend services (default: 100)",
)
parser.add_argument(
"--queue-size",
type=int,
default=500,
help="Maximum number of requests in the queue (default: 500)",
)
parser.add_argument(
"--rate-limit",
type=int,
default=40,
help="Maximum requests per second (default: 40)",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port to run the server on (default: 8000)",
)
parser.add_argument(
"--prefill-url",
type=str,
default="http://localhost:8100/v1/completions",
help="Prefill service endpoint URL",
)
parser.add_argument(
"--decode-url",
type=str,
default="http://localhost:8200/v1/completions",
help="Decode service endpoint URL",
)
return parser.parse_args()
def main():
"""parse command line arguments"""
args = parse_args()
# Initialize configuration using command line parameters
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
MAX_CONCURRENT_REQUESTS = args.max_concurrent
REQUEST_QUEUE_SIZE = args.queue_size
RATE_LIMIT = args.rate_limit
PREFILL_SERVICE_URL = args.prefill_url
DECODE_SERVICE_URL = args.decode_url
PORT = args.port
app = Quart(__name__)
# Initialize the rate limiter and request queue
rate_limiter = RateLimiter(RATE_LIMIT)
request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE)
# Attach the configuration object to the application instance
app.config.update(
{
"AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
"rate_limiter": rate_limiter,
"request_queue": request_queue,
"PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
"DECODE_SERVICE_URL": DECODE_SERVICE_URL,
}
)
# Start queue processing on app startup
@app.before_serving
async def startup():
"""Start request processing task when app starts serving"""
asyncio.create_task(request_queue.process())
async def forward_request(url, data):
"""Forward request to backend service with rate limiting and error handling"""
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
async with session.post(url=url, json=data, headers=headers) as response:
if response.status == 200:
# if response.headers.get('Transfer-Encoding') == 'chunked':
if True:
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
else:
content = await response.read()
yield content
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
try:
original_request_data = await request.get_json()
prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill
prefill_request["max_tokens"] = 1
# finish prefill
async for _ in forward_request(
"http://localhost:8100/v1/completions", prefill_request
# Use rate limiter as context manager
async with (
rate_limiter,
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
):
continue
try:
async with session.post(
url=url, json=data, headers=headers
) as response:
if response.status == 200:
# Stream response chunks
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
else:
# Handle backend service errors
error_text = await response.text()
logger.error(
"Backend service error: %s - %s",
response.status,
error_text,
)
yield b'{"error": "Backend service error"}'
except aiohttp.ClientError as e:
# Handle connection errors
logger.error("Connection error to %s: %s", url, str(e))
yield b'{"error": "Service unavailable"}'
except asyncio.TimeoutError:
# Handle timeout errors
logger.error("Timeout connecting to %s", url)
yield b'{"error": "Service timeout"}'
# return decode
generator = forward_request(
"http://localhost:8200/v1/completions", original_request_data
)
response = await make_response(generator)
response.timeout = None
async def process_request():
"""Process a single request through prefill and decode stages"""
try:
original_request_data = await request.get_json()
return response
# Create prefill request (max_tokens=1)
prefill_request = original_request_data.copy()
prefill_request["max_tokens"] = 1
except Exception as e:
import sys
import traceback
# Execute prefill stage
async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request):
continue
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server")
print(e)
print("".join(traceback.format_exception(*exc_info)))
# Execute decode stage and stream response
generator = forward_request(DECODE_SERVICE_URL, original_request_data)
response = await make_response(generator)
response.timeout = None # Disable timeout for streaming response
return response
except Exception:
logger.exception("Error processing request")
return Response(
response=b'{"error": "Internal server error"}',
status=500,
content_type="application/json",
)
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
"""Handle incoming API requests with concurrency and rate limiting"""
# Create task for request processing
task = asyncio.create_task(process_request())
# Enqueue request or reject if queue is full
if not await request_queue.enqueue(task):
return Response(
response=b'{"error": "Server busy, try again later"}',
status=503,
content_type="application/json",
)
try:
# Return the response from the processing task
return await task
except asyncio.CancelledError:
# Handle task cancellation (timeout or queue full)
logger.warning("Request cancelled due to timeout or queue full")
return Response(
response=b'{"error": "Request cancelled"}',
status=503,
content_type="application/json",
)
# Start the Quart server with host can be set to 0.0.0.0
app.run(port=PORT)
if __name__ == "__main__":
app.run(port=8000)
main()

View File

@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
class RateLimiter:
"""Token bucket rate limiter implementation"""
def __init__(self, rate_limit):
self.rate_limit = rate_limit # Requests per second
self.num_available_tokens = rate_limit # Available tokens
self.last_refill = time.monotonic() # Last token refill time
self.lock = asyncio.Lock() # Synchronization lock
async def acquire(self):
"""Acquire a token from the rate limiter"""
while True:
async with self.lock:
current_time = time.monotonic()
elapsed = current_time - self.last_refill
# Refill num_available_tokens if more than 1 second has passed
if elapsed > 1.0:
self.num_available_tokens = self.rate_limit
self.last_refill = current_time
# Check if num_available_tokens are available
if self.num_available_tokens > 0:
self.num_available_tokens -= 1
return True
# Calculate wait time if no num_available_tokens available
wait_time = 1.0 - elapsed
await asyncio.sleep(wait_time)
async def __aenter__(self):
"""Enter async context manager - acquire token"""
await self.acquire()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
"""Exit async context manager - no cleanup needed"""
pass

View File

@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections import deque
class RequestQueue:
"""Request queue manager with concurrency control"""
def __init__(self, max_concurrent, max_queue_size):
# Maximum concurrent requests
self.max_concurrent = max_concurrent
self.max_queue_size = max_queue_size # Maximum queue size
# Concurrency control
self.semaphore = asyncio.Semaphore(max_concurrent)
self.queue = deque() # Request queue
self.queue_size = 0 # Current queue size
self.lock = asyncio.Lock() # Sync queue Lock
async def enqueue(self, task):
"""Add a request task to the queue"""
async with self.lock:
if self.queue_size >= self.max_queue_size:
return False
self.queue.append(task)
self.queue_size += 1
return True
async def process(self):
"""Process queued requests using semaphore for concurrency control"""
while True:
if self.queue:
async with self.semaphore, self.lock:
task = self.queue.popleft()
self.queue_size -= 1
await task
await asyncio.sleep(0.01) # Yield control to event loop

View File

@ -236,6 +236,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
a=bt.a,
c=None,
b_q_weight=w_q,
b_bias=None,
b_scales=w_s,
global_scale=None,
b_zeros=w_zp,

View File

@ -321,6 +321,8 @@ static inline constexpr auto kFE3M2f =
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE8M0fnu =
ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);

View File

@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME {
TEMPLATE = ("template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
@ -77,6 +78,7 @@ def generate_new_kernels():
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
# nvfp4 only supports group_size == 16
# mxfp4 only supports group_size == 32
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
@ -89,9 +91,22 @@ def generate_new_kernels():
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
s_type = "vllm::kFE4M3fn"
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
s_type = "vllm::kFE8M0fnu"
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype,
w_type_id=scalar_type + ".id()",
s_type_id=s_type + ".id()",
threads=threads,
thread_m_blocks=max(m_blocks, 1),
thread_n_blocks=n_blocks,

View File

@ -7,23 +7,25 @@
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the

View File

@ -280,6 +280,7 @@ __device__ inline void wait_negative_and_add(int* lock) {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
@ -299,6 +300,7 @@ __global__ void Marlin(
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ b_bias_ptr,
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
@ -318,8 +320,9 @@ __global__ void Marlin(
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce, // whether to use fp32 global reduce
bool has_bias,
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce, // whether to use fp32 global reduce
int max_shared_mem) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
@ -342,12 +345,23 @@ __global__ void Marlin(
extern __shared__ int4 sh[];
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
if constexpr (w_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
static_assert(s_type == vllm::kBFloat16);
} else if constexpr (std::is_same<scalar_t, half>::value) {
static_assert(s_type == vllm::kFloat16);
}
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
!is_int_type ||
w_type == vllm::kFE4M3fn ||
w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
@ -365,6 +379,7 @@ __global__ void Marlin(
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4);
const int b_bias_expert_stride = prob_n / 8;
// parallel: num valid moe blocks
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
@ -475,7 +490,7 @@ __global__ void Marlin(
for (int i = 0; i < 4; i++) {
int idx = tid4 * 4 + i;
idx = idx < block_num_valid_tokens ? idx : 0;
if constexpr (w_type == vllm::kFE2M1f) {
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
sh_block_topk_weights[idx] = __hmul2(
global_scale, Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[idx]])));
@ -513,7 +528,7 @@ __global__ void Marlin(
expert_id = expert_ids_ptr[block_id];
}
if constexpr (w_type == vllm::kFE2M1f) {
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
uint16_t val = scale2_ptr[expert_id];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
}
@ -526,6 +541,9 @@ __global__ void Marlin(
if constexpr (has_act_order) {
g_idx += (expert_id - old_expert_id) * prob_k;
}
if (has_bias) {
b_bias_ptr += (expert_id - old_expert_id) * b_bias_expert_stride;
}
read_moe_block_data(block_id);
};
@ -721,7 +739,7 @@ __global__ void Marlin(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;
} else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
@ -734,6 +752,18 @@ __global__ void Marlin(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
int bias_sh_rd;
if constexpr (m_block_size_8) {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8;
} else {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
}
int bias_sh_wr = threadIdx.x;
int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr int num_col_threads = 8;
@ -793,7 +823,19 @@ __global__ void Marlin(
constexpr int sh_b_size = stages * b_sh_stage;
int4* sh_b = sh_new;
int4* sh_red = sh_new;
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_size_b_red_min =
(sh_red_size < sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_size_b_red_max =
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_bias_size = (thread_n_blocks * 16 / 8);
constexpr int sh_b_red_bias_size =
sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size)
? sh_size_b_red_max
: (sh_size_b_red_min + sh_bias_size);
int4* sh_bias = sh_new + sh_size_b_red_min;
int4* sh_g_idx = sh_new + sh_b_red_bias_size;
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
@ -803,9 +845,9 @@ __global__ void Marlin(
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size;
constexpr int shm_size_used =
moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
constexpr int shm_size_used = moe_block_size +
stages * (g_idx_stage + zp_sh_stage) +
sh_s_size + sh_b_red_bias_size;
// all remaining shared memory is used to cache A (input)
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
@ -816,7 +858,8 @@ __global__ void Marlin(
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order
FragS frag_s[2][4]; // No act-order
FragS frag_bias[2][4];
FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16
@ -1065,10 +1108,15 @@ __global__ void Marlin(
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
} else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) +
k % 2];
}
}
}
@ -1281,9 +1329,9 @@ __global__ void Marlin(
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
dequant_fp8_scales<scalar_t2>(s_quant_0,
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2>(
dequant_fp8_scales<scalar_t2, s_type_id>(
s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2, s_type_id>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
}
@ -1566,7 +1614,7 @@ __global__ void Marlin(
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto write_result = [&]() {
auto write_result = [&](bool last) {
int c_gl_stride = prob_n / 8;
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
@ -1592,7 +1640,7 @@ __global__ void Marlin(
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s) {
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
scalar_t2 res =
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
@ -1601,14 +1649,27 @@ __global__ void Marlin(
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) {
res = __hmul2(res, s[0]);
scalar_t2 tmp_scale = s[0];
if constexpr (m_block_size_8) {
tmp_scale = Dtype::num2num2(
reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);
}
res = __hmul2(res, tmp_scale);
}
if constexpr (w_type == vllm::kFE2M1f) {
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
if (!mul_topk_weights) {
res = __hmul2(res, global_scale);
}
}
if (has_bias && last) {
scalar_t2 tmp_bias = b_bias[0];
if constexpr (m_block_size_8) {
tmp_bias = Dtype::num2num2(
reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);
}
res = __hadd2(res, tmp_bias);
}
if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x;
@ -1626,19 +1687,25 @@ __global__ void Marlin(
if constexpr (m_block_size_8) {
int wr = c_sh_wr + 16 * j;
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
frag_s[j / 2][2 * (j % 2) + 0]);
frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
frag_s[j / 2][2 * (j % 2) + 1]);
frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
} else {
int wr = c_sh_wr + 8 * j;
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
}
}
c_sh_wr += 16 * (4 * c_sh_stride);
@ -1805,6 +1872,14 @@ __global__ void Marlin(
}
thread_block_reduce();
if (has_bias && last) {
__syncthreads();
cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd],
threadIdx.x < 16 * thread_n_blocks / 8);
cp_async_fence();
}
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
@ -1867,11 +1942,20 @@ __global__ void Marlin(
}
barrier_release(&locks[locks_off], last);
}
if (has_bias && last) {
cp_async_wait<0>();
__syncthreads();
reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
__syncthreads();
}
if (use_atomic_add && slice_count > 1 && slice_idx != 0)
wait_negative_and_add(&locks[locks_off]);
if (last || use_atomic_add)
// only the last block in a slice actually writes the result
write_result();
write_result(last);
int old_slice_row = slice_row;
slice_row = 0;
slice_col_par++;
@ -1904,6 +1988,7 @@ __global__ void Marlin(
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
}
bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
// Update slice k/n for scales loading
if constexpr (has_act_order) {
slice_k_start = tb_k * slice_row;

View File

@ -51,8 +51,9 @@ __global__ void permute_cols_kernel(
} // namespace marlin
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
@ -212,7 +213,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16);
int tb_m = thread_m_blocks * 16;
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
@ -220,6 +221,11 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2;
int tmp_size =
(sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);
int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
@ -234,8 +240,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
sh_zp_size = sh_s_size / 2;
}
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
sh_zp_size + sh_g_idx_size + sh_block_meta_size;
int total_size = tmp_size + sh_a_size + sh_s_size + sh_zp_size +
sh_g_idx_size + sh_block_meta_size;
return total_size;
}
@ -270,20 +276,25 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size <= max_shared_mem;
return cache_size + 512 <= max_shared_mem;
}
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
constexpr auto S_TYPE = \
W_TYPE == vllm::kFE2M1f \
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
: vllm::kBFloat16); \
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
@ -335,31 +346,45 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF(W_TYPE) \
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF(W_TYPE) \
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF(W_TYPE) \
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
@ -408,12 +433,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
NVFP4_GET_IF(vllm::kFE2M1f)
FP4_GET_IF(vllm::kFE2M1f)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128)
if (std::is_same<scalar_t, nv_bfloat16>::value) {
if (false) {
}
MXFP4_GET_IF(vllm::kFE2M1f)
}
return kernel;
}
@ -482,16 +512,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
}
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
void* sorted_token_ids, void* expert_ids,
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* s, void* s2, void* zp, void* g_idx, void* perm,
void* a_tmp, void* sorted_token_ids, void* expert_ids,
void* num_tokens_past_padded, void* topk_weights,
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
int prob_m, int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& q_type, bool has_act_order,
bool is_k_full, bool has_zp, int num_groups, int group_size,
int dev, cudaStream_t stream, int thread_k, int thread_n,
int sms, bool use_atomic_add, bool use_fp32_reduce,
vllm::ScalarType const& q_type, bool has_bias,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k,
int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
int thread_m_blocks = div_ceil(moe_block_size, 16);
bool m_block_size_8 = moe_block_size == 8;
@ -538,6 +568,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias;
const int4* s_ptr = (const int4*)s;
const uint16_t* s2_ptr = (const uint16_t*)s2;
const int4* zp_ptr = (const int4*)zp;
@ -648,10 +679,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem);
// clang-format on
}
@ -659,7 +690,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
@ -766,7 +798,6 @@ torch::Tensor moe_wna16_marlin_gemm(
num_groups = b_scales.size(1);
torch::Tensor g_idx, perm, a_tmp;
;
if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
g_idx = g_idx_or_none.value();
perm = perm_or_none.value();
@ -815,12 +846,24 @@ torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor global_scale;
if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
"global_scale can only be used for float4_e2m1f.");
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
"global_scale can only be used for nvfp4 format.");
} else {
global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
"the global_scale parameter must be passed for float4_e2m1f.");
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
"the global_scale parameter must be passed for nvfp4 format.");
}
bool has_bias = b_bias_or_none.has_value();
torch::Tensor b_bias;
if (has_bias) {
b_bias = b_bias_or_none.value();
TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU");
TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous");
TORCH_CHECK(b_bias.size(1) == size_n, "b_bias.size(0) != size_n");
TORCH_CHECK(b_bias.stride(1) == 1, "b_bias.stride(1) != 1");
} else {
b_bias = torch::empty({0}, options);
}
torch::Tensor b_zeros;
@ -832,7 +875,6 @@ torch::Tensor moe_wna16_marlin_gemm(
b_zeros = torch::empty({0}, options);
}
bool has_zp = b_zeros.size(-1) > 0;
if (has_zp) {
TORCH_CHECK(
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
@ -890,41 +932,58 @@ torch::Tensor moe_wna16_marlin_gemm(
if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep,
size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order,
is_k_full, has_zp, num_groups, group_size, dev,
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::Half>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else {
TORCH_CHECK(false,
"moe_wna16_marlin_gemm only supports bfloat16 and float16");

View File

@ -35,7 +35,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
"Tensor! b_q_weight, Tensor? b_bias_or_none,"
"Tensor! b_scales, Tensor? global_scale, Tensor? "
"b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"

View File

@ -145,22 +145,6 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
void gelu_quick(torch::Tensor& out, torch::Tensor& input);
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
int64_t block_size, torch::Tensor& input_tokens,
torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions,
torch::Tensor& seq_lens,
torch::Tensor& slot_mapping,
torch::Tensor& block_tables);
void advance_step_flashinfer(
int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions, torch::Tensor& seq_lens,
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
torch::Tensor const& q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,

View File

@ -1,336 +0,0 @@
/*
* The goal of this GPU kernel is to advance input tensors on the GPU directly
* PR: https://github.com/vllm-project/vllm/pull/6338
* Current restrictions:
* 1. Specialized for DraftModelRunner
* 2. Supports flash_attn only
*/
#include "advance_step.cuh"
namespace prepare_inputs {
//
template <int const num_threads>
__global__ void advance_step_flashattn_kernel(
int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
long const* sampled_token_ids_ptr, long* input_positions_ptr,
int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
int64_t const block_tables_stride) {
int const n_pad = num_seqs - num_queries;
if (n_pad && blockIdx.x == 0) {
// Handle cuda graph padding
int const offset = num_queries;
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
input_tokens_ptr[offset + i] = 0;
input_positions_ptr[offset + i] = 0;
slot_mapping_ptr[offset + i] = -1;
}
}
int num_query_blocks = div_ceil(num_queries, num_threads);
if (blockIdx.x >= num_query_blocks) {
return;
}
int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
if (cur_query_id >= num_queries) {
return;
}
// Update input_tokens
input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
int seq_len = seq_lens_ptr[cur_query_id];
int next_seq_len = seq_len + 1;
int next_input_pos = next_seq_len - 1;
// Update seq_lens
seq_lens_ptr[cur_query_id] = next_seq_len;
// Update input_positions
input_positions_ptr[cur_query_id] = next_input_pos;
int const* seq_block_tables_ptr =
block_tables_ptr + block_tables_stride * cur_query_id;
int block_index = next_input_pos / block_size;
int block_offset = next_input_pos % block_size;
int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
// Update slot_mapping
slot_mapping_ptr[cur_query_id] = slot_num;
}
inline void verify_tensor(std::string const& name, torch::Tensor const& t,
int64_t const size_0, int64_t const size_1,
c10::ScalarType const type) {
bool size_0_cond = true;
if (size_0 != -1) {
size_0_cond = t.size(0) == size_0;
}
bool size_1_cond = true;
if (size_1 != -1) {
size_1_cond = t.size(1) == size_1;
}
bool is_contiguous = t.is_contiguous();
bool same_type = t.dtype() == type;
bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
if (!pass) {
TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
" is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
" is not as expected: shape = [", size_0, ", ", size_1,
"], type = ", type);
}
}
/// each thread processes a block per query
__global__ void advance_step_flashinfer_kernel(
int num_threads, int num_seqs, int num_queries, int block_size,
long* input_tokens_ptr, long const* sampled_token_ids_ptr,
long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
int const* block_tables_ptr, int64_t const block_tables_stride,
int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
int const n_pad = num_seqs - num_queries;
if (n_pad && blockIdx.x == 0) {
// Handle cuda graph padding
int const offset = num_queries;
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
input_tokens_ptr[offset + i] = 0;
input_positions_ptr[offset + i] = 0;
slot_mapping_ptr[offset + i] = -1;
}
}
int num_query_blocks = div_ceil(num_queries, num_threads);
if (blockIdx.x < num_query_blocks) {
int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
if (cur_query_id < num_queries) {
// Update input_tokens
input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
int seq_len = seq_lens_ptr[cur_query_id];
int next_seq_len = seq_len + 1;
int next_input_pos = next_seq_len - 1;
// Update seq_lens
seq_lens_ptr[cur_query_id] = next_seq_len;
// Update input_positions
input_positions_ptr[cur_query_id] = next_input_pos;
int const* seq_block_tables_ptr =
block_tables_ptr + block_tables_stride * cur_query_id;
int block_index = next_input_pos / block_size;
int block_offset = next_input_pos % block_size;
// Update paged_kv_last_page_len
paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;
int slot_num =
seq_block_tables_ptr[block_index] * block_size + block_offset;
// Update slot_mapping
slot_mapping_ptr[cur_query_id] = slot_num;
block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
}
}
}
__global__ void advance_step_flashinfer_indptr_kernel(
int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
int* block_table_bound_ptr) {
int idx = blockIdx.x * num_threads + threadIdx.x;
// Update paged_kv_indptr
if (idx == 0) {
paged_kv_indptr_ptr[idx] = 0;
}
if (idx < num_queries) {
int sum = 0;
for (int i = 0; i <= idx; ++i) {
sum += block_table_bound_ptr[i];
}
paged_kv_indptr_ptr[idx + 1] = sum;
}
}
__global__ void advance_step_flashinfer_indices_kernel(
int num_seqs, int num_queries, int const* block_tables_ptr,
int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr,
int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
// note: max_num_blocks_per_seq = block_tables.stride(0)
int tid = blockIdx.x * blockDim.x + threadIdx.x;
// when cuda graphs are enabled, paged_kv_indptr tensor
// has to be updated for the padded queries
// tid represents a query# for paged_kv_indptr tensor
if (num_queries < tid && tid <= num_seqs) {
paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries];
}
// each thread processes a block_ptr in block_tables
// block_tables shape: [num_queries, max_num_blocks_per_seq]
// paged_kv_indices is flattened block_tables.
for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq);
idx += (gridDim.x * blockDim.x)) {
// block_tables-row = paged_kv_indptr[queryNum]
int queryNum = idx / max_num_blocks_per_seq;
int col = idx % max_num_blocks_per_seq;
if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) {
int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col;
int block_tables_idx = queryNum * max_num_blocks_per_seq + col;
paged_kv_indices_ptr[indices_arr_idx] =
block_tables_ptr[block_tables_idx];
}
}
}
void advance_step_flashattn(int num_seqs, int num_queries, int block_size,
torch::Tensor& input_tokens, // type: long
torch::Tensor& sampled_token_ids, // type: long
torch::Tensor& input_positions, // type: long
torch::Tensor& seq_lens, // type: int
torch::Tensor& slot_mapping, // type: long
torch::Tensor& block_tables) { // type: int
if (logging) {
printf("advance_step_flashattn:\n");
printf(" num_seqs = %d\n", num_seqs);
printf(" num_queries = %d\n", num_queries);
printf(" block_size = %d\n", block_size);
}
// Verify all tensors
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
at::kLong);
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
int dev = sampled_token_ids.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
int blocks;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
advance_step_flashattn_kernel<max_threads>
<<<blocks, max_threads, 0, stream>>>(
num_seqs, num_queries, block_size,
reinterpret_cast<long*>(input_tokens.data_ptr()),
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
reinterpret_cast<long*>(input_positions.data_ptr()),
reinterpret_cast<int*>(seq_lens.data_ptr()),
reinterpret_cast<long*>(slot_mapping.data_ptr()),
reinterpret_cast<int const*>(block_tables.data_ptr()),
block_tables.stride(0));
}
void advance_step_flashinfer(
int num_seqs, int num_queries, int block_size,
torch::Tensor& input_tokens, // type: long
torch::Tensor& sampled_token_ids, // type: long
torch::Tensor& input_positions, // type: long
torch::Tensor& seq_lens, // type: int
torch::Tensor& slot_mapping, // type: long
torch::Tensor& block_tables, // type: int
torch::Tensor& paged_kv_indices, // type: int
torch::Tensor& paged_kv_indptr, // type: int
torch::Tensor& paged_kv_last_page_len, // type: int
torch::Tensor& block_table_bound) { // type: int
if (logging) {
printf("advance_step_flashinfer:\n");
printf(" num_seqs = %d\n", num_seqs);
printf(" num_queries = %d\n", num_queries);
printf(" block_size = %d\n", block_size);
printf(" block_tables.stride(0) = %zu\n", block_tables.stride(0));
}
// Verify all tensors
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
// verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
// at::kLong);
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt);
verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
at::kInt);
verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);
int dev = sampled_token_ids.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
int blocks;
int threads;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
TORCH_CHECK((blocks * threads > num_queries),
"multi-step: not enough threads to map to num_queries = ",
num_queries, " block_tables.stride(0) = ", block_tables.stride(0),
" blocks = ", blocks, " max_threads = ", threads);
if (logging) {
printf("launching kernels with %d blocks and %d threads\n", blocks,
threads);
}
advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
threads, num_seqs, num_queries, block_size,
reinterpret_cast<long*>(input_tokens.data_ptr()),
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
reinterpret_cast<long*>(input_positions.data_ptr()),
reinterpret_cast<int*>(seq_lens.data_ptr()),
reinterpret_cast<long*>(slot_mapping.data_ptr()),
reinterpret_cast<int const*>(block_tables.data_ptr()),
block_tables.stride(0),
reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
reinterpret_cast<int*>(block_table_bound.data_ptr()));
advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
threads, num_seqs, num_queries,
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
reinterpret_cast<int*>(block_table_bound.data_ptr()));
advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
num_seqs, num_queries,
reinterpret_cast<int const*>(block_tables.data_ptr()),
block_tables.stride(0),
reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
reinterpret_cast<int*>(block_table_bound.data_ptr()));
}
} // namespace prepare_inputs
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
int64_t block_size, torch::Tensor& input_tokens,
torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions,
torch::Tensor& seq_lens,
torch::Tensor& slot_mapping,
torch::Tensor& block_tables) {
prepare_inputs::advance_step_flashattn(
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping, block_tables);
}
void advance_step_flashinfer(
int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions, torch::Tensor& seq_lens,
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
prepare_inputs::advance_step_flashinfer(
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
}

View File

@ -1,19 +0,0 @@
#pragma once
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
namespace prepare_inputs {
static constexpr int max_threads = 256;
static constexpr bool logging = false;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
} // namespace prepare_inputs

View File

@ -470,11 +470,12 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
template <typename scalar_t2>
template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
template <>
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
__device__ inline void dequant_fp8_scales<half2, vllm::kFE4M3fn.id()>(
int q, half2* frag_b) {
int Out1 = (q & 0xFF00FF00) >> 1;
;
q <<= 8;
@ -486,8 +487,8 @@ __device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
};
template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
nv_bfloat162* frag_b) {
__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE4M3fn.id()>(
int q, nv_bfloat162* frag_b) {
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
constexpr int MASK = 0x7F007F00;
@ -502,6 +503,20 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
int q, nv_bfloat162* frag_b) {
// In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16,
// but we assume that such a extreme value would not occur in real models.
int Out1 = (q & 0xFF00FF00) >> 1;
q <<= 7;
int Out2 = q & 0x7F807F80;
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
#endif
} // namespace MARLIN_NAMESPACE_NAME

View File

@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME {
TEMPLATE = ("template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
@ -78,7 +79,8 @@ def generate_new_kernels():
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
# nvfp4 only supports group_size == 16
if scalar_type == "vllm::kFE2M1f" and group_blocks != 1:
# mxfp4 only supports group_size == 32
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
@ -97,10 +99,23 @@ def generate_new_kernels():
# 4bit quantization and fp16
is_zp_float_list.append(True)
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
s_type = "vllm::kFE4M3fn"
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
s_type = "vllm::kFE8M0fnu"
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
for is_zp_float in is_zp_float_list:
template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype,
w_type_id=scalar_type + ".id()",
s_type_id=s_type + ".id()",
threads=threads,
thread_m_blocks=max(m_blocks, 1),
thread_n_blocks=n_blocks,

View File

@ -48,7 +48,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
torch::Tensor gptq_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
@ -187,7 +188,12 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
int tb_m = thread_m_blocks * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8);
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2;
int tmp_size =
(sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);
int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
@ -202,8 +208,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
sh_zp_size = sh_s_size / 2;
}
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
sh_zp_size + sh_g_idx_size;
int total_size =
tmp_size + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size;
return total_size;
}
@ -237,20 +243,25 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size <= max_shared_mem;
return cache_size + 512 <= max_shared_mem;
}
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
constexpr auto S_TYPE = \
W_TYPE == vllm::kFE2M1f \
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
: vllm::kBFloat16); \
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
@ -315,22 +326,39 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF(W_TYPE) \
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M234(W_TYPE, 4, 8, 128)
#define NVFP4_GET_IF(W_TYPE) \
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF(W_TYPE) \
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
@ -384,7 +412,7 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
FP4_GET_IF(vllm::kFE2M1f)
NVFP4_GET_IF(vllm::kFE2M1f)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
@ -396,6 +424,11 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
}
FZP_GET_IF(vllm::kU4)
}
if (std::is_same<scalar_t, nv_bfloat16>::value) {
if (false) {
}
MXFP4_GET_IF(vllm::kFE2M1f)
}
return kernel;
}
@ -453,12 +486,12 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
}
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
int prob_m, int prob_n, int prob_k, int lda, void* workspace,
vllm::ScalarType const& q_type, bool has_act_order,
bool is_k_full, bool has_zp, int num_groups, int group_size,
int dev, cudaStream_t stream, int thread_k_init,
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* s, void* s2, void* zp, void* g_idx, void* perm,
void* a_tmp, int prob_m, int prob_n, int prob_k, int lda,
void* workspace, vllm::ScalarType const& q_type, bool has_bias,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k_init,
int thread_n_init, int sms, bool use_atomic_add,
bool use_fp32_reduce, bool is_zp_float) {
if (has_zp) {
@ -503,6 +536,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias;
const int4* s_ptr = (const int4*)s;
const uint16_t* s2_ptr = (const uint16_t*)s2;
const int4* zp_ptr = (const int4*)zp;
@ -623,8 +657,9 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups,
prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add,
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr,
g_idx_ptr, num_groups,
prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add,
use_fp32_reduce, max_shared_mem_new);
// clang-format on
@ -638,7 +673,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch::Tensor gptq_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
@ -785,12 +821,24 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor global_scale;
if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
"global_scale can only be used for float4_e2m1f.");
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
"global_scale can only be used for nvfp4 format.");
} else {
global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
"the global_scale parameter must be passed for float4_e2m1f.");
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
"the global_scale parameter must be passed for nvfp4 format.");
}
bool has_bias = b_bias_or_none.has_value();
torch::Tensor b_bias;
if (has_bias) {
b_bias = b_bias_or_none.value();
TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU");
TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous");
TORCH_CHECK(b_bias.size(0) == size_n, "b_bias.size(0) != size_n");
TORCH_CHECK(b_bias.stride(0) == 1, "b_bias.stride(0) != 1");
} else {
b_bias = torch::empty({0}, options);
}
torch::Tensor b_zeros;
@ -857,34 +905,50 @@ torch::Tensor gptq_marlin_gemm(
if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
marlin::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k, a.stride(0),
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order,
is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
marlin::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type,
has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else {

View File

@ -10,15 +10,18 @@
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
int max_shared_mem
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the

View File

@ -39,6 +39,7 @@ namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
@ -271,6 +272,7 @@ __device__ inline void wait_negative_and_add(int* lock) {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
@ -290,6 +292,7 @@ __global__ void Marlin(
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ b_bias_ptr,
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
@ -297,12 +300,13 @@ __global__ void Marlin(
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int lda, // A.stride(0), equal to prob_k is A is contiguous
int* locks, // extra global storage for barrier synchronization
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int lda, // A.stride(0), equal to prob_k is A is contiguous
int* locks, // extra global storage for barrier synchronization
bool has_bias,
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce, // whether to use fp32 global reduce
int max_shared_mem) {
@ -326,18 +330,29 @@ __global__ void Marlin(
using FragZP = typename ScalarType<scalar_t>::FragZP;
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
if constexpr (w_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
static_assert(s_type == vllm::kBFloat16);
} else if constexpr (std::is_same<scalar_t, half>::value) {
static_assert(s_type == vllm::kFloat16);
}
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
!is_int_type ||
w_type == vllm::kFE4M3fn ||
w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
scalar_t2 global_scale;
if constexpr (w_type == vllm::kFE2M1f) {
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
// NVFP4 format requires global scale
uint16_t val = scale2_ptr[0];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
}
@ -589,7 +604,7 @@ __global__ void Marlin(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;
} else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
@ -602,6 +617,18 @@ __global__ void Marlin(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
int bias_sh_rd;
if constexpr (m_block_size_8) {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8;
} else {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
}
int bias_sh_wr = threadIdx.x;
int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr int num_col_threads = 8;
@ -670,7 +697,19 @@ __global__ void Marlin(
constexpr int sh_b_size = stages * b_sh_stage;
int4* sh_b = sh;
int4* sh_red = sh;
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_size_b_red_min =
(sh_red_size < sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_size_b_red_max =
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_bias_size = (thread_n_blocks * 16 / 8);
constexpr int sh_b_red_bias_size =
sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size)
? sh_size_b_red_max
: (sh_size_b_red_min + sh_bias_size);
int4* sh_bias = sh + sh_size_b_red_min;
int4* sh_g_idx = sh + sh_b_red_bias_size;
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
@ -680,15 +719,13 @@ __global__ void Marlin(
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size;
// constexpr int shm_size_used =
// stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
// (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order
FragS frag_s[2][4]; // No act-order
FragS frag_bias[2][4];
FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16
@ -923,10 +960,15 @@ __global__ void Marlin(
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
} else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) +
k % 2];
}
}
}
@ -1139,9 +1181,9 @@ __global__ void Marlin(
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
dequant_fp8_scales<scalar_t2>(s_quant_0,
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2>(
dequant_fp8_scales<scalar_t2, s_type_id>(
s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2, s_type_id>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
}
@ -1411,7 +1453,7 @@ __global__ void Marlin(
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto write_result = [&]() {
auto write_result = [&](bool last) {
int c_gl_stride = prob_n / 8;
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
@ -1438,7 +1480,7 @@ __global__ void Marlin(
int c_gl_wr_end = c_gl_stride * prob_m;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s) {
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
scalar_t2 res =
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
@ -1447,12 +1489,25 @@ __global__ void Marlin(
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) {
res = __hmul2(res, s[0]);
scalar_t2 tmp_scale = s[0];
if constexpr (m_block_size_8) {
tmp_scale = Dtype::num2num2(
reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);
}
res = __hmul2(res, tmp_scale);
}
if constexpr (w_type == vllm::kFE2M1f) {
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
res = __hmul2(res, global_scale);
}
if (has_bias && last) {
scalar_t2 tmp_bias = b_bias[0];
if constexpr (m_block_size_8) {
tmp_bias = Dtype::num2num2(
reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);
}
res = __hadd2(res, tmp_bias);
}
if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x;
@ -1470,19 +1525,25 @@ __global__ void Marlin(
if constexpr (m_block_size_8) {
int wr = c_sh_wr + 16 * j;
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
frag_s[j / 2][2 * (j % 2) + 0]);
frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
frag_s[j / 2][2 * (j % 2) + 1]);
frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
} else {
int wr = c_sh_wr + 8 * j;
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0],
frag_bias[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1],
frag_bias[j / 2][2 * (j % 2) + 1]);
}
}
c_sh_wr += 16 * (4 * c_sh_stride);
@ -1622,6 +1683,14 @@ __global__ void Marlin(
}
thread_block_reduce();
if (has_bias && last) {
__syncthreads();
cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd],
threadIdx.x < 16 * thread_n_blocks / 8);
cp_async_fence();
}
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
@ -1684,11 +1753,20 @@ __global__ void Marlin(
}
barrier_release(&locks[locks_off], last);
}
if (has_bias && last) {
cp_async_wait<0>();
__syncthreads();
reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
__syncthreads();
}
if (use_atomic_add && slice_count > 1 && slice_idx != 0)
wait_negative_and_add(&locks[locks_off]);
if (last || use_atomic_add)
// only the last block in a slice actually writes the result
write_result();
write_result(last);
slice_row = 0;
slice_col_par++;
slice_col++;
@ -1706,6 +1784,7 @@ __global__ void Marlin(
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
}
bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
// Update slice k/n for scales loading
if constexpr (has_act_order) {
slice_k_start = tb_k * slice_row;

View File

@ -142,25 +142,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
// prepare_inputs advance_step
ops.def(
"advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
"Tensor! input_tokens, Tensor sampled_token_ids, "
"Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
"Tensor block_tables) -> ()");
ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn);
ops.def(
"advance_step_flashinfer("
" int num_seqs, int num_queries, int block_size,"
" Tensor! input_tokens, Tensor sampled_token_ids,"
" Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
" Tensor block_tables, Tensor! paged_kv_indices,"
" Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
" Tensor! block_table_bounds"
") -> ()");
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
@ -326,6 +307,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def(
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor? b_bias_or_none,"
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "

View File

@ -497,14 +497,11 @@ ENV HF_HUB_ENABLE_HF_TRANSFER 1
# Copy in the v1 package for testing (it isn't distributed yet)
COPY vllm/v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
# doc requires source code
# we hide them inside `test_docs/` , so that this source code
# Source code is used in the `python_only_compile.sh` test
# We hide it inside `src/` so that this source code
# will not be imported by other tests
RUN mkdir test_docs
RUN mv docs test_docs/
RUN cp -r examples test_docs/
RUN mv vllm test_docs/
RUN mv mkdocs.yaml test_docs/
RUN mkdir src
RUN mv vllm src/vllm
#################### TEST IMAGE ####################
#################### OPENAI API SERVER ####################

View File

@ -2,4 +2,5 @@ Loading Model weights with fastsafetensors
===================================================================
Using fastsafetensors library enables loading model weights to GPU memory by leveraging GPU direct storage. See [their GitHub repository](https://github.com/foundation-model-stack/fastsafetensors) for more details.
For enabling this feature, set the environment variable ``USE_FASTSAFETENSOR`` to ``true``
To enable this feature, use the ``--load-format fastsafetensors`` command-line argument

View File

@ -35,6 +35,7 @@ You can check if this is happening by trying the old defaults with `--generation
If other strategies don't solve the problem, it's likely that the vLLM instance is stuck somewhere. You can use the following environment variables to help debug the issue:
- `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging.
- `export VLLM_LOG_STATS_INTERVAL=1.` to get log statistics more frequently for tracking running queue, waiting queue and cache hit states.
- `export CUDA_LAUNCH_BLOCKING=1` to identify which CUDA kernel is causing the problem.
- `export NCCL_DEBUG=TRACE` to turn on more logging for NCCL.
- `export VLLM_TRACE_FUNCTION=1` to record all function calls for inspection in the log files to tell which function crashes or hangs. Do not use this flag unless absolutely needed for debugging, it will cause significant delays in startup time.

View File

@ -3,7 +3,8 @@
import contextlib
import os
import weakref
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Optional
import pytest
@ -32,27 +33,130 @@ def temporary_environ(env_vars):
os.environ[k] = v
@dataclass
class BackendConfig:
name: str
env_vars: dict
comp_config: dict
specific_gpu_arch: Optional[tuple] = None
# Define all backend configurations of full cudagraph to be tested
backend_configs = {
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
comp_config={
"cudagraph_mode": "FULL",
},
specific_gpu_arch=(9, 0)),
# FlashMLA on Hopper
"FlashMLA":
BackendConfig(name="FlashMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# Cutlass MLA on Blackwell
"CutlassMLA":
BackendConfig(
name="CutlassMLA",
env_vars={
"VLLM_USE_V1": "1",
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
"FORCE_NUM_KV_SPLITS":
"1", # TODO: remove this when hang issue is fixed
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
"cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
},
specific_gpu_arch=(10, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
comp_config={
"cudagraph_mode": "FULL",
}),
# Triton Attention
"TritonAttn":
BackendConfig(name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
comp_config={
"cudagraph_mode": "FULL",
}),
# FlashInfer
"FlashInfer":
BackendConfig(name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
}
test_params_full_cudagraph = []
# deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends = ["FlashMLA", "CutlassMLA"]
for mla_backend in MLA_backends:
test_params_full_cudagraph.append(
pytest.param(
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))
# Qwen/Qwen2-1.5B-Instruct with other backends
other_backend_configs = [
backend_configs[c] for c in backend_configs if c not in MLA_backends
]
for backend_config in other_backend_configs:
test_params_full_cudagraph.append(
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))
@pytest.fixture(scope="class")
def llm_pair(request):
model = request.param
model, backend_config = request.param
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}):
# Dynamically skip test if GPU capability is not met
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
!= current_platform.get_device_capability():
if backend_config.specific_gpu_arch == (9, 0):
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
elif backend_config.specific_gpu_arch == (10, 0):
pytest.skip("Only Blackwell GPUs support Cutlass MLA")
env_vars = {
"VLLM_USE_V1": "1",
# Force native sampler to avoid potential nondeterminism in FlashInfer
# when per-request generators are not used in V1.
"VLLM_USE_FLASHINFER_SAMPLER": "0",
**backend_config.env_vars,
}
with temporary_environ(env_vars):
full = LLM(
model=model,
gpu_memory_utilization=0.45,
gpu_memory_utilization=0.43,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(full_cuda_graph=True),
max_num_seqs=128,
compilation_config=\
CompilationConfig(**backend_config.comp_config),
generation_config="vllm",
seed=42,
)
piecewise = LLM(
model=model,
gpu_memory_utilization=0.45,
gpu_memory_utilization=0.43,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(),
max_num_seqs=128,
compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
generation_config="vllm",
seed=42,
)
# PyTest caches the fixture values so we use weakref.proxy to enable GC
@ -66,16 +170,7 @@ def llm_pair(request):
)
@pytest.mark.parametrize(
"llm_pair",
[
# Model names for the llm_pair fixture
"deepseek-ai/DeepSeek-V2-Lite",
"Qwen/Qwen2-1.5B-Instruct"
],
indirect=True)
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
reason="Only Hopper GPUs support FA3 and FlashMLA")
@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
class TestFullCUDAGraph:
"""
Use a class such that an llm pair is constructed once for all
@ -104,12 +199,14 @@ class TestFullCUDAGraph:
full cudagraph compilation works for padded cases too.
"""
piecewise_llm, full_cudagraph_llm = llm_pair
full_cudagraph_llm, piecewise_llm = llm_pair
prompts = ["Hello, my name is"] * batch_size
prompts = ["the quick brown fox"] * batch_size
# Use purely greedy decoding to avoid top-p truncation sensitivity
# that can amplify tiny numeric differences across runtimes.
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
top_p=0.95)
top_p=1.0)
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
@ -117,42 +214,16 @@ class TestFullCUDAGraph:
# Check that all responses are the same
for piecewise_res, full_res in zip(piecewise_responses,
full_responses):
assert piecewise_res.outputs[0].text == full_res.outputs[0].text
@pytest.mark.parametrize(
"model, supported",
[
("Qwen/Qwen2-1.5B-Instruct", True),
# MLA does not support capturing CUDA Graphs with size > max_num_seqs
("deepseek-ai/DeepSeek-V2-Lite", False),
])
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
reason="Only Hopper GPUs support FA3 and FlashMLA")
def test_lower_max_num_seqs(model, supported):
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}), ExitStack() as stack:
if not supported:
stack.enter_context(pytest.raises(RuntimeError))
llm = LLM(model=model,
max_num_seqs=256,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(
full_cuda_graph=True,
cudagraph_capture_sizes=[64, 256, 512]))
llm.generate(["Hello, my name is"] * 10)
assert piecewise_res.outputs[0].text.lower() == \
full_res.outputs[0].text.lower()
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION":
"2" #FA2 not supported with full_cuda_graph
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION"
# Flex_Attention is not supported with full cuda graph
}), pytest.raises(RuntimeError):
LLM(model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(full_cuda_graph=True))
compilation_config=CompilationConfig(cudagraph_mode="FULL"))

View File

@ -11,10 +11,10 @@ from torch.library import Library
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import set_forward_context
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op
global_counter = 0
@ -101,16 +101,33 @@ def test_simple_piecewise_compile(use_inductor):
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
), set_forward_context({}, vllm_config=vllm_config):
), set_forward_context(None,
vllm_config=vllm_config): # background context
# warm up with background context
model(inputs)
model(torch.randn(2).cuda())
model(torch.randn(1).cuda())
# capturing/replaying should under context of cudagraph dispatching
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )):
model(torch.randn(2).cuda())
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=1, )):
model(torch.randn(1).cuda())
input = torch.zeros(2).cuda()
global global_counter
global_counter = 0
output = model(input)
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )):
output = model(input)
assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))

View File

@ -18,9 +18,9 @@ from torch.library import Library
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.forward_context import set_forward_context
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op
# create a library to hold the custom op
@ -276,9 +276,11 @@ def run_model(llama_config,
)
if split_attn:
compilation_config.splitting_ops = ["silly.attention"]
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, )
cudagraph_runtime_mode = CUDAGraphMode.NONE
vllm_config = VllmConfig(compilation_config=compilation_config,
additional_config=llama_config)
@ -287,17 +289,37 @@ def run_model(llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda()
with set_forward_context({}, vllm_config=vllm_config):
with set_forward_context({},
vllm_config=vllm_config): # background context
B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
positions = torch.arange(B).cuda()
# warmup for the model with cudagraph_mode NONE
model(input_ids, positions)
model(input_ids[:2], positions[:2])
model(input_ids[:1], positions[:1])
# simulate cudagraphs capturing
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
model(input_ids[:2], positions[:2])
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=1, )):
model(input_ids[:1], positions[:1])
input_ids[:2].zero_()
output = model(input_ids[:2], positions[:2])
# simulate cudagraphs replay
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
output = model(input_ids[:2], positions[:2])
output = output.cpu()

View File

@ -9,10 +9,10 @@ import torch
import vllm.v1.attention.backends.rocm_aiter_fa # noqa: F401
from vllm.platforms import current_platform
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
BLOCK_SIZES = [16]
DTYPES = [torch.bfloat16]
QDTYPES = [None]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check

View File

@ -29,17 +29,14 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [
torch.half, torch.bfloat16, torch.float
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
DTYPES = [torch.bfloat16]
NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
# This should be sync with get_supported_head_sizes() in
# vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
HEAD_SIZES = [32, 80, 128, 256]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]

View File

@ -11,11 +11,11 @@ from vllm import _custom_ops as ops
from vllm.platforms import current_platform
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
DTYPES = [torch.bfloat16, torch.float]
NUM_TOKENS = [42] # Arbitrary values for testing
NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 120, 256]
HEAD_SIZES = [64, 80, 256]
BLOCK_SIZES = [8, 16, 32]
CACHE_LAYOUTS = ["NHD", "HND"]

View File

@ -12,14 +12,16 @@ from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_with_kvcache,
is_fa_version_supported)
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
BLOCK_SIZES = [16]
DTYPES = [torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
SOFT_CAPS = [None, 50.0]
SLIDING_WINDOWS = [None, 256]
def ref_paged_attn(
@ -83,9 +85,9 @@ def ref_paged_attn(
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
@ -198,9 +200,9 @@ def test_flash_attn_with_paged_kv(
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)

View File

@ -9,11 +9,13 @@ import torch
from vllm.platforms import current_platform
NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
NUM_HEADS = [(32, 8), (6, 1)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
DTYPES = [torch.bfloat16]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
SOFT_CAPS = [None, 30.0]
SLIDING_WINDOWS = [None, 64]
def ref_paged_attn(
@ -76,8 +78,8 @@ def ref_paged_attn(
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@pytest.mark.parametrize("sliding_window", [None, 64])
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@torch.inference_mode
def test_flashinfer_decode_with_paged_kv(
kv_lens: list[int],
@ -173,8 +175,8 @@ def test_flashinfer_decode_with_paged_kv(
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@pytest.mark.parametrize("sliding_window", [None, 64])
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@torch.inference_mode
def test_flashinfer_prefill_with_paged_kv(
seq_lens: list[tuple[int, int]],
@ -278,11 +280,11 @@ def test_flashinfer_prefill_with_paged_kv(
@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]])
@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
def test_flashinfer_prefill_with_paged_fp8_kv(
seq_lens: list[tuple[int, int]], num_heads: tuple[int, int],
head_size: int, dtype: torch.dtype, block_size: int,
@ -385,11 +387,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@pytest.mark.skip(reason="TODO: fix the accuracy issue")
@torch.inference_mode
def test_flashinfer_decode_with_paged_fp8_kv(
kv_lens: list[int],
@ -399,7 +402,6 @@ def test_flashinfer_decode_with_paged_fp8_kv(
block_size: int,
soft_cap: Optional[float],
) -> None:
pytest.skip("TODO: fix the accuracy issue")
# test doesn't work for num_heads = (16,16)
torch.set_default_device("cuda")
current_platform.seed_everything(0)

View File

@ -20,11 +20,11 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
MAX_Q_LEN = 1024
MAX_KV_LEN = 4096
BATCH_SIZES = [4, 12]
NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)]
NUM_HEADS = [(16, 16), (40, 8)]
HEAD_SIZES = [128]
BLOCK_SIZES = [16, 32]
BLOCK_SIZES = [16]
KV_LAYOUTS = ["HND"]
DTYPES = [torch.float16, torch.bfloat16]
DTYPES = [torch.bfloat16]
KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
SOFT_CAPS = [None, 50.0]

View File

@ -19,13 +19,13 @@ from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64]
HEAD_SIZES = [128, 96, 24]
NUM_QUERIES_PER_KV = [1, 64]
HEAD_SIZES = [24, 128]
DTYPES = [torch.float16]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
SLIDING_WINDOW = [0, 16, 2048]
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
OPS = [chunked_prefill_paged_decode, context_attention_fwd]

View File

@ -9,11 +9,11 @@ import torch
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.platforms import current_platform
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
BLOCK_SIZES = [16]
DTYPES = [torch.float16, torch.bfloat16]
DTYPES = [torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
None, torch.float8_e4m3fnuz
]
@ -85,7 +85,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("soft_cap", [None, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()

View File

@ -9,7 +9,7 @@ from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba_attn import (
from vllm.v1.attention.backends.mamba2_attn import (
_query_start_loc_to_chunk_indices_offsets)
# Added by the IBM Team, 2024

View File

@ -89,14 +89,11 @@ class BatchedMMTensors:
return BatchedMMTensors(A, B, C, num_expert_tokens)
@pytest.mark.parametrize("num_experts", [8, 16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
[32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 1024])
@pytest.mark.parametrize(
"dtype",
[torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("num_experts", [8, 32])
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
@pytest.mark.parametrize("K", [128, 1024])
@pytest.mark.parametrize("N", [128, 1024])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,

View File

@ -113,8 +113,7 @@ def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
rtol=0)
@pytest.mark.parametrize(
"num_tokens", [1, 4, 8, 11, 19, 128, 127, 405, 1024, 3333, 6666, 7317])
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 11, 127, 128, 3333, 7317])
@pytest.mark.parametrize("num_topk", [2, 6, 8])
@pytest.mark.parametrize("num_experts", [64])
@pytest.mark.parametrize("ep_size", [1, 2, 4])
@ -126,7 +125,7 @@ def test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
ep_size, topk_ids_dtype)
@pytest.mark.parametrize("numel", list(range(1, 8192, 11)))
@pytest.mark.parametrize("numel", list(range(1, 8192, 111)))
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("ep_size", [2])
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])

View File

@ -24,8 +24,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_permute_bias)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
rand_marlin_weight_fp4_like)
rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@ -40,6 +42,24 @@ NUM_EXPERTS = [8, 64, 192]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]
FUSED_MOE_MNK_FACTORS = [
(1, 128, 128),
(1, 2048, 128),
(33, 2048, 128),
(222, 1024, 1024),
(32768, 128, 128),
(32768, 2048, 511),
(40000, 1024, 1024),
]
FUSED_MOE_WN16_MNK_FACTORS = [
(1, 128, 128),
(1, 1024, 1024),
(32, 2048, 128),
(32, 1024, 1024),
(222, 2048, 1024),
]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@ -114,13 +134,11 @@ def run_moe_test(
return baseline_output
@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize("chunk_size", [8192])
def test_fused_moe(
@ -233,13 +251,11 @@ def test_fused_moe(
use_cudagraph=use_cudagraph)
@pytest.mark.parametrize("m", [1, 32, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
@ -350,8 +366,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
@ -476,8 +491,11 @@ def marlin_moe_generate_valid_test_cases():
if quant_type == scalar_types.float8_e4m3fn and \
group_size not in [-1, 128]:
return False
if quant_type == scalar_types.float4_e2m1f and group_size != 16:
return False
if quant_type == scalar_types.float4_e2m1f:
if group_size not in [16, 32]:
return False
if dtype == torch.float16 and group_size == 32:
return False
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
return False
@ -520,31 +538,6 @@ def test_fused_marlin_moe(
torch.cuda.manual_seed(0)
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
if quant_type == scalar_types.float8_e4m3fn:
if group_size not in [-1, 128]:
return
if act_order:
return
# Filter act_order
if act_order:
if quant_type == scalar_types.float8_e4m3fn:
return
if group_size == -1:
return
if group_size in (k, n):
return
if has_zp:
return
else:
if not is_k_full:
return
if quant_type == scalar_types.float4_e2m1f and group_size != 16:
return
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
return
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
@ -569,13 +562,19 @@ def test_fused_marlin_moe(
for i in range(w1.shape[0]):
if quant_type == scalar_types.float4_e2m1f:
w_ref1, qweight1, scales1, global_scale1 = \
rand_marlin_weight_fp4_like(w1[i], group_size)
if group_size == 16:
w_ref1, qweight1, scales1, global_scale1 = \
rand_marlin_weight_nvfp4_like(w1[i], group_size)
else:
w_ref1, qweight1, scales1 = \
rand_marlin_weight_mxfp4_like(w1[i], group_size)
global_scale1 = None
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
global_scale1_l.append(global_scale1)
if global_scale1 is not None:
global_scale1_l.append(global_scale1)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
w1[i], group_size)
@ -620,13 +619,19 @@ def test_fused_marlin_moe(
for i in range(w2.shape[0]):
if quant_type == scalar_types.float4_e2m1f:
w_ref2, qweight2, scales2, global_scale2 = \
rand_marlin_weight_fp4_like(w2[i], group_size)
if group_size == 16:
w_ref2, qweight2, scales2, global_scale2 = \
rand_marlin_weight_nvfp4_like(w2[i], group_size)
else:
w_ref2, qweight2, scales2 = \
rand_marlin_weight_mxfp4_like(w2[i], group_size)
global_scale2 = None
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
global_scale2_l.append(global_scale2)
if global_scale2 is not None:
global_scale2_l.append(global_scale2)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
w2[i], group_size)
@ -677,6 +682,8 @@ def test_fused_marlin_moe(
a,
qweight1,
qweight2,
None,
None,
scales1,
scales2,
score,
@ -698,6 +705,119 @@ def test_fused_marlin_moe(
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 256])
def test_fused_marlin_moe_with_bias(m):
torch.cuda.manual_seed(0)
e, topk = 32, 4
n, k = 2048, 2048
group_size = 128
act_order = False
is_k_full = True
quant_type = scalar_types.uint4b8
dtype = torch.half
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10
b_bias1_l = []
w_ref1_l = []
qweight1_l = []
scales1_l = []
g_idx1_l = []
sort_indices1_l = []
for i in range(w1.shape[0]):
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
marlin_quantize(w1[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
b_bias1_l.append(marlin_permute_bias(b_bias1[i]))
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
global_scale1 = None
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None
b_bias2_l = []
w_ref2_l = []
qweight2_l = []
scales2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
marlin_quantize(w2[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
b_bias2_l.append(marlin_permute_bias(b_bias2[i]))
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
global_scale2 = None
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1,
b_bias2)
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
qweight1,
qweight2,
marlin_bias1,
marlin_bias2,
scales1,
scales2,
score,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=None,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
quant_type_id=quant_type.id,
is_k_full=is_k_full)
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
def test_moe_align_block_size_opcheck():
num_experts = 4
block_size = 4

View File

@ -15,10 +15,10 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
from vllm.platforms import current_platform
from vllm.utils import round_up
NUM_TOKENS = [1, 3, 7, 16, 256, 2256, 4096]
NUM_EXPERTS = [32, 160, 256, 257, 512]
NUM_TOKENS = [1, 3, 256, 2256, 4096]
NUM_EXPERTS = [32, 160, 256, 257]
TOP_KS = [1, 2, 16, 32]
BLOCK_SIZES = [32, 64, 128, 256]
BLOCK_SIZES = [32, 128]
current_platform.seed_everything(0)

View File

@ -18,7 +18,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
from vllm.platforms import current_platform
NUM_EXPERTS = [16, 64, 256]
TOP_KS = [2, 4, 6, 8]
TOP_KS = [2, 6, 8]
EP_SIZE = [1, 4, 16]
current_platform.seed_everything(0)
@ -177,11 +177,11 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor,
return output
@pytest.mark.parametrize("n_token", [1, 33, 64, 222, 1024, 2048, 3000, 5000])
@pytest.mark.parametrize("n_hidden", [2048, 4096, 7168])
@pytest.mark.parametrize("n_token", [1, 33, 1024, 5000])
@pytest.mark.parametrize("n_hidden", [2048, 7168])
@pytest.mark.parametrize("n_expert", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("align_block_size", [None, 128])
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,

View File

@ -44,6 +44,14 @@ requires_pplx = pytest.mark.skipif(
reason="Requires PPLX kernels",
)
BATCHED_MOE_MNK_FACTORS = [
(1, 128, 128),
(33, 2048, 128),
(64, 128, 2048),
(222, 128, 128),
(222, 2048, 1024),
]
PPLX_COMBOS = [
# TODO: figure out why this fails, seems to be test problem
#(1, 128, 128),
@ -152,9 +160,7 @@ def torch_batched_moe(
return torch_finalize(out, topk_weight, topk_ids)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("m,n,k", BATCHED_MOE_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])

View File

@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX,
convert_swizzled_to_linear, dequantize_nvfp4_to_dtype)
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm
if not current_platform.has_device_capability(100):
pytest.skip(
reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True,
)
DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)
SEEDS = [42]
CUDA_DEVICES = ["cuda:0"]
def get_ref_results(
a_fp4,
b_fp4,
a_sf,
b_sf,
a_global_scale,
b_global_scale,
m,
n,
dtype,
block_size,
device,
):
_, m_k = a_fp4.shape
_, n_k = b_fp4.shape
assert m_k == n_k
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_sf,
a_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4,
b_sf,
b_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
return torch.matmul(a_in_dtype, b_in_dtype.t())
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("backend", ["cutlass", "trtllm"])
@pytest.mark.parametrize("autotune", [False, True])
@torch.inference_mode()
def test_flashinfer_nvfp4_gemm(
dtype: torch.dtype,
shape: tuple[int, int, int],
seed: int,
device: str,
backend: str,
autotune: bool,
) -> None:
if backend == "trtllm" and dtype == torch.float16:
pytest.skip(
"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")
current_platform.seed_everything(seed)
m, n, packed_k = shape
k = packed_k * 2
block_size = 16
a_dtype = torch.randn((m, k), dtype=dtype, device=device)
b_dtype = torch.randn((n, k), dtype=dtype, device=device)
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
alpha = 1.0 / (a_global_scale * b_global_scale)
# ops.scaled_fp4_quant returns swizzled scales, while weights
# from checkpoints are in linear scales.
# So instead of needing to swizzle for cutlass as in modelopt.py,
# we need to unswizzle for trtllm here.
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
# get_ref_results unswizzles the scales internally.
expected_out = get_ref_results(
a_fp4,
b_fp4,
a_scale_interleaved,
b_scale_interleaved,
a_global_scale,
b_global_scale,
m,
n,
dtype,
block_size,
device,
)
import flashinfer
if backend == "trtllm":
epilogue_tile_m = 128
b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8),
epilogue_tile_m)
b_scale_interleaved = convert_swizzled_to_linear(
b_scale_interleaved, n, k, block_size)
b_scale_interleaved = (flashinfer.shuffle_matrix_sf_a(
b_scale_interleaved.view(torch.uint8), epilogue_tile_m).reshape(
b_scale_interleaved.shape).view(torch.float8_e4m3fn))
with flashinfer.autotune(autotune):
out = flashinfer_scaled_fp4_mm(
a_fp4,
b_fp4,
a_scale_interleaved,
b_scale_interleaved,
alpha,
dtype,
backend=backend,
)
torch.testing.assert_close(out,
expected_out.to(dtype=dtype),
atol=1e-1,
rtol=1e-1)

View File

@ -11,11 +11,9 @@ from tests.kernels.quant_utils import (FP8_DTYPE,
from tests.kernels.utils import opcheck
from vllm.platforms import current_platform
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
8193] # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
DTYPES = [torch.bfloat16, torch.float]
HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
NUM_TOKENS = [1, 7, 4096]
SCALE_UBS = [True, False]
SEEDS = [0]

View File

@ -9,10 +9,9 @@ from tests.kernels.utils import opcheck
from vllm._custom_ops import scaled_int8_quant
from vllm.platforms import current_platform
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 5137, 8193] # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
DTYPES = [torch.bfloat16, torch.float]
HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
NUM_TOKENS = [1, 7, 4096]
SEEDS = [0]
SCALE = [0.1, 2.1]

View File

@ -34,8 +34,6 @@ IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
MNK_SHAPES = [
(1, 128, 128),
(1, 512, 1024),
(1, 4096, 4096),
(1, 8192, 28672),
(13, 8192, 4096),
(26, 4096, 8192),
@ -43,8 +41,6 @@ MNK_SHAPES = [
(64, 8192, 28672),
(257, 128, 4096),
(257, 4224, 4160),
(257, 4096, 4096),
(1024, 4096, 8192),
(1024, 8192, 4096),
]

View File

@ -19,10 +19,11 @@ from vllm.model_executor.layers.quantization.qqq import (
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_permute_scales,
marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like)
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@ -39,7 +40,7 @@ from vllm.scalar_type import scalar_types
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [True]
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]
@ -52,12 +53,8 @@ HQQ_SUPPORTED_GROUP_SIZES = [64]
MNK_FACTORS = [
(1, 1, 1),
(1, 4, 8),
(1, 7, 5),
(13, 17, 67),
(26, 37, 13),
(67, 13, 11),
(257, 13, 11),
(658, 13, 11),
]
DTYPES = [torch.float16, torch.bfloat16]
@ -202,17 +199,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_gptq_marlin_gemm(
k_chunk,
n_chunk,
quant_type,
group_size,
mnk_factors,
act_order,
is_k_full,
use_atomic_add,
use_fp32_reduce,
):
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
mnk_factors, act_order, is_k_full, use_atomic_add,
use_fp32_reduce, dtype):
m_factor, n_factor, k_factor = mnk_factors
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
@ -231,14 +221,23 @@ def test_gptq_marlin_gemm(
if size_k % group_size != 0:
return
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
a_input = rand_data((size_m, size_k), dtype)
b_weight = rand_data((size_k, size_n), dtype)
if quant_type == scalar_types.float4_e2m1f:
if group_size != 16 or act_order:
if group_size not in [16, 32] or act_order:
return
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
b_weight.T, group_size)
if group_size == 32 and dtype == torch.float16:
return
if group_size == 16:
w_ref, marlin_q_w, marlin_s, marlin_s2 = \
rand_marlin_weight_nvfp4_like(b_weight.T, group_size)
else:
w_ref, marlin_q_w, marlin_s = \
rand_marlin_weight_mxfp4_like(b_weight.T, group_size)
marlin_s2 = None
g_idx = None
sort_indices = None
marlin_zp = None
@ -272,8 +271,8 @@ def test_gptq_marlin_gemm(
workspace = marlin_make_workspace_new(w_ref.device)
opcheck(torch.ops._C.gptq_marlin_gemm,
(a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx,
sort_indices, workspace, quant_type.id, a_input.shape[0],
(a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp,
g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
@ -282,6 +281,7 @@ def test_gptq_marlin_gemm(
a_input,
None,
marlin_q_w,
None,
marlin_s,
marlin_s2,
marlin_zp,
@ -418,6 +418,7 @@ def test_hqq_marlin_gemm(
a_input,
None,
marlin_w_q,
None,
marlin_s,
None,
marlin_zp,
@ -531,6 +532,7 @@ def test_marlin_gemm_subset_input():
a_input,
None,
marlin_q_w,
None,
marlin_s,
None,
marlin_zp,
@ -555,6 +557,53 @@ def test_marlin_gemm_subset_input():
assert max_diff < 0.04
@pytest.mark.parametrize("size_m", [1, 256])
def test_marlin_gemm_with_bias(size_m):
quant_type = scalar_types.uint4b8
group_size = 128
size_k, size_n = 1024, 2048
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
b_bias = rand_data((size_n, )) * 10
marlin_bias = marlin_permute_bias(b_bias)
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, False)
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = marlin_make_workspace_new(a_input.device)
output = ops.gptq_marlin_gemm(
a_input,
None,
marlin_q_w,
marlin_bias,
marlin_s,
None,
marlin_zp,
g_idx,
sort_indices,
workspace,
quant_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full=True,
use_atomic_add=False,
use_fp32_reduce=True,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04
def test_marlin_gemm_opcheck():
size_m = 2048
size_n = 4096

View File

@ -65,9 +65,12 @@ def test_nvfp4_gemm(
b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
alpha = 1. / (a_global_scale * b_global_scale)
# ops.scaled_fp4_quant returns swizzled scales, while weights
# from checkpoints are in linear scales.
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
# get_ref_results unswizzles the scales internally.
expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved,
b_scale_interleaved, a_global_scale,
b_global_scale, m, n, dtype, block_size,

View File

@ -8,15 +8,55 @@ from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform
DTYPES = [torch.bfloat16, torch.float16]
M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192]
K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192] # k % 8 == 0
N = [1, 2, 3, 4]
# Specific (N, K, M) combinations for targeted testing
NKM_FACTORS_LLMM1 = [
# Small, medium, large cases
(1, 8, 16),
(1, 32, 64),
(1, 128, 256),
(1, 512, 1024),
(1, 2048, 4096),
# Edge cases with specific K sizes
(1, 6144, 1024),
(1, 8192, 2048),
# Very large case
(1, 4096, 8192),
]
NKM_FACTORS_WVSPLITK = [
# Different batch sizes with key dimensions
(1, 16, 16),
(1, 64, 64),
(2, 256, 256),
(3, 1024, 1024),
(4, 4096, 4096),
# Extended K values
(1, 9216, 512),
(2, 10240, 1024),
(4, 16384, 8192),
# Minimum M constraint validation (m >= 8)
(1, 64, 8),
(2, 128, 8),
(4, 256, 8),
]
NKM_FACTORS_WVSPLITK_FP8 = [
# FP8-specific cases with K % 16 == 0
(1, 16, 16),
(1, 64, 64),
(2, 512, 512),
(3, 2048, 2048),
(4, 4096, 4096),
# Extended FP8 dimensions not covered by WVSPLITK
(1, 14336, 1024),
(2, 24576, 2048),
(4, 32768, 28672),
]
SEEDS = [0]
@pytest.mark.parametrize("n", [1]) # only test for batch size 1
@pytest.mark.parametrize("k", K)
@pytest.mark.parametrize("m", M)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
@pytest.mark.parametrize("seed", SEEDS)
@ -34,9 +74,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
assert torch.allclose(out, ref_out, rtol=0.01)
@pytest.mark.parametrize("n", N) # only test for batch size <= 4
@pytest.mark.parametrize("k", K + [9216, 10240, 16384])
@pytest.mark.parametrize("m", [8] + M) # m >= 8
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
@ -54,9 +92,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
assert torch.allclose(out, ref_out, rtol=0.01)
@pytest.mark.parametrize("n", N) # only test for batch size <= 4
@pytest.mark.parametrize("k", K[1:] + [14336, 24576, 32768]) # k % 16 == 0
@pytest.mark.parametrize("m", M + [28672]) # m >= 16
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(

View File

@ -60,10 +60,18 @@ def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
num_logprobs)
@pytest.mark.parametrize("M", [1, 33, 64, 512])
@pytest.mark.parametrize("N", [256, 971, 20486])
@pytest.mark.parametrize("K", [128, 496, 1024])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
MNK_FACTORS = [
(1, 256, 128),
(33, 256, 496),
(64, 971, 1024),
(64, 20486, 128),
(512, 256, 496),
(512, 20486, 1024),
]
@pytest.mark.parametrize("M,N,K", MNK_FACTORS)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
@pytest.mark.parametrize("in_dtype", get_8bit_types())
@pytest.mark.parametrize("use_scalar_scale_a", [True, False])
@pytest.mark.parametrize("use_scalar_scale_b", [True, False])

View File

@ -1064,6 +1064,8 @@ def torch_experts(
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
b_bias1: Optional[torch.Tensor] = None,
b_bias2: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
@ -1108,8 +1110,13 @@ def torch_experts(
if mask.sum():
if quant_dtype is None:
tmp1 = a[mask] @ w1[i].transpose(0, 1)
if b_bias1 is not None:
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
tmp2 = SiluAndMul()(tmp1)
out[mask] = tmp2 @ w2[i].transpose(0, 1)
if b_bias2 is not None:
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(
tmp1.dtype)
elif block_shape is not None:
# block quantized
assert (a_scale is not None and w1_scale is not None
@ -1117,6 +1124,8 @@ def torch_experts(
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
w1_scale[i], block_shape,
out.dtype)
if b_bias1 is not None:
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
tmp2 = SiluAndMul()(tmp1)
tmp2, b_scale = moe_kernel_quantize_input(
tmp2, a2_scale, quant_dtype, per_act_token_quant,
@ -1125,6 +1134,9 @@ def torch_experts(
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
w2_scale[i], block_shape,
out.dtype)
if b_bias2 is not None:
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(
tmp1.dtype)
else:
assert (a_scale is not None and w1_scale is not None
and w2_scale is not None)
@ -1133,6 +1145,8 @@ def torch_experts(
tmp1 = a[mask].to(f32) * scales
w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
tmp1 = (tmp1 @ w1_dq).to(out.dtype)
if b_bias1 is not None:
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(out.dtype)
tmp2 = SiluAndMul()(tmp1).to(out.dtype)
@ -1144,6 +1158,9 @@ def torch_experts(
tmp2 = tmp2.to(f32) * b_scale
w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
out[mask] = (tmp2 @ w2_dq).to(out.dtype)
if b_bias2 is not None:
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(
out.dtype)
if apply_router_weights_on_input:
return out
@ -1157,12 +1174,14 @@ def torch_moe(a: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
b_bias1: Optional[torch.Tensor] = None,
b_bias2: Optional[torch.Tensor] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts,
expert_map)
b_bias1, b_bias2, expert_map)
def torch_moe_single(a, w, score, topk):

View File

@ -57,6 +57,13 @@ V1_SUPPORTED_MODELS = [
# Avoid OOM
MAX_NUM_SEQS = 4
# Once we add support for FCG in Mamba1, this list will be removed and tests
# all test cases will use enforce_eager=False
ENFORCE_EAGER_MODELS_V1 = [
"state-spaces/mamba-130m-hf",
"ai21labs/Jamba-tiny-dev",
]
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@ -94,13 +101,19 @@ def test_models(
example_prompts, max_tokens, num_logprobs)
if model in V1_SUPPORTED_MODELS:
enforce_eager = False
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
if model in HYBRID_MODELS:
# required due to reorder_batch behaviour
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
if model in ENFORCE_EAGER_MODELS_V1:
enforce_eager = True
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enforce_eager=enforce_eager,
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
@ -418,3 +431,65 @@ def test_full_cuda_graph(
name_0="hf" if hf_outputs is not None else "vllm-v0",
name_1="vllm-v1",
)
@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_fp32_state(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
num_logprobs: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
with hf_runner(model) as hf_model:
if model not in HF_UNSUPPORTED_MODELS:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
else:
hf_outputs = None
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
mamba_ssm_cache_dtype="float32") as vllm_model:
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
if model in HYBRID_MODELS:
# required due to reorder_batch behaviour
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
mamba_ssm_cache_dtype="float32",
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
if hf_outputs is not None:
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v0_outputs,
name_0="hf",
name_1="vllm-v0",
)
ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
check_logprobs_close(
outputs_0_lst=ref_outputs,
outputs_1_lst=vllm_v1_outputs,
name_0="hf" if hf_outputs is not None else "vllm-v0",
name_1="vllm-v1",
)

View File

@ -18,7 +18,7 @@ from tests.models.utils import EmbedModelInfo, RerankModelInfo
# - Different model results in differences more than 1e-3
# 1e-4 is a good tolerance threshold
MTEB_EMBED_TASKS = ["STS12"]
MTEB_EMBED_TOL = 1e-4
MTEB_EMBED_TOL = 0.02
# See #19344
MTEB_RERANK_TASKS = ["NFCorpus"]
@ -175,6 +175,7 @@ def mteb_test_embed_models(hf_runner,
with vllm_runner(model_info.name,
runner="pooling",
max_model_len=None,
enforce_eager=True,
**vllm_extra_kwargs) as vllm_model:
model_config = vllm_model.llm.llm_engine.model_config
@ -198,6 +199,7 @@ def mteb_test_embed_models(hf_runner,
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
st_dtype = next(hf_model.model.parameters()).dtype
print("Model:", model_info.name)
print("VLLM:", vllm_dtype, vllm_main_score)
print("SentenceTransformers:", st_dtype, st_main_score)
print("Difference:", st_main_score - vllm_main_score)
@ -286,6 +288,7 @@ def mteb_test_rerank_models(hf_runner,
runner="pooling",
max_model_len=None,
max_num_seqs=8,
enforce_eager=True,
**vllm_extra_kwargs) as vllm_model:
model_config = vllm_model.llm.llm_engine.model_config
@ -304,6 +307,7 @@ def mteb_test_rerank_models(hf_runner,
st_main_score, st_dtype = mteb_test_rerank_models_hf(
hf_runner, model_info.name, hf_model_callback)
print("Model:", model_info.name)
print("VLLM:", vllm_dtype, vllm_main_score)
print("SentenceTransformers:", st_dtype, st_main_score)
print("Difference:", st_main_score - vllm_main_score)

View File

@ -151,7 +151,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5",
trust_remote_code=True),
"BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1",
min_transformers_version="4.55.1",
min_transformers_version="4.56.0",
extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m",
{"1b": "bigscience/bloomz-1b1"}),
@ -227,7 +227,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
"JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini",
min_transformers_version="4.55.1",
min_transformers_version="4.56.0",
extras={
"tiny": "ai21labs/Jamba-tiny-dev",
"random": "ai21labs/Jamba-tiny-random", # noqa: E501

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import uuid
from pathlib import Path
import numpy as np
@ -72,3 +73,22 @@ def test_hash_non_contiguous_array():
hasher = MultiModalHasher
# Both should be hashable and produce the same hashes
assert hasher.hash_kwargs(data=arr) == hasher.hash_kwargs(data=arr_c)
def test_hash_image_exif_id():
# Test that EXIF ImageId tag can be used to store UUID
# and the hasher will use that instead of the image data.
image1 = image2 = Image.new("1", size=(10, 20))
id = uuid.uuid4()
image1.getexif()[Image.ExifTags.Base.ImageID] = id
image2 = Image.open(ASSETS_DIR / "image1.png")
image2.getexif()[Image.ExifTags.Base.ImageID] = "Not a UUID"
image2a = Image.open(ASSETS_DIR / "image1.png")
hasher = MultiModalHasher
# first image has UUID in ImageID, so it should hash to that UUID
assert hasher.hash_kwargs(image=image1) == hasher.hash_kwargs(
image=id.bytes)
# second image has non-UUID in ImageID, so it should hash to the image data
assert hasher.hash_kwargs(image=image2) == hasher.hash_kwargs(
image=image2a)

View File

@ -148,6 +148,32 @@ async def test_fetch_image_local_files(image_url: str):
f"file://{temp_dir}/../{os.path.basename(image_url)}")
@pytest.mark.asyncio
async def test_fetch_image_local_files_with_space_in_name():
image_url = TEST_IMAGE_URLS[0]
connector = MediaConnector()
with TemporaryDirectory() as temp_dir:
local_connector = MediaConnector(allowed_local_media_path=temp_dir)
origin_image = connector.fetch_image(image_url)
filename = "file name with space.jpg"
origin_image.save(os.path.join(temp_dir, filename),
quality=100,
icc_profile=origin_image.info.get('icc_profile'))
try:
image_async = await local_connector.fetch_image_async(
f"file://{temp_dir}/{filename}")
image_sync = local_connector.fetch_image(
f"file://{temp_dir}/{filename}")
except FileNotFoundError as e:
pytest.fail(
"Failed to fetch image with space in name: {}".format(e))
# Check that the images are equal
assert not ImageChops.difference(image_sync, image_async).getbbox()
@pytest.mark.asyncio
async def test_fetch_image_error_conversion():
connector = MediaConnector()

View File

@ -10,7 +10,7 @@ cd /vllm-workspace/
# uninstall vllm
pip3 uninstall -y vllm
# restore the original files
mv test_docs/vllm ./vllm
mv src/vllm ./vllm
# remove all compilers
apt remove --purge build-essential -y

View File

@ -4,7 +4,7 @@
import pytest
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend

View File

View File

@ -0,0 +1,406 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
from tests.utils import create_new_process_for_each_test
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
ParallelConfig, SchedulerConfig, VllmConfig)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# Helper MLP for testing
class SimpleMLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 10)
self.fc2 = nn.Linear(10, 10)
def forward(self, x):
return self.fc2(self.fc1(x))
def _create_vllm_config(compilation_config: CompilationConfig,
max_num_seqs: int = 8) -> MagicMock:
mock_config = MagicMock(spec=VllmConfig)
mock_config.compilation_config = compilation_config
mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
mock_config.parallel_config = ParallelConfig()
# Mimic the behavior of VllmConfig.__post_init__()
if compilation_config.level == CompilationLevel.PIECEWISE:
compilation_config.set_splitting_ops_for_v1()
return mock_config
class TestCudagraphDispatcher:
@pytest.mark.parametrize(
"params",
[
# Test case 0: Full CG for mixed batches, no separate routine
{
"case_id": 0,
"cudagraph_mode": "FULL",
"compilation_level": CompilationLevel.NO_COMPILATION,
},
# Test case 1: Full CG for uniform batches, piecewise for mixed
{
"case_id": 1,
"cudagraph_mode": "FULL_AND_PIECEWISE",
"compilation_level": CompilationLevel.PIECEWISE,
},
# Test case 2: Full CG for uniform batches, no CG for mixed
{
"case_id": 2,
"cudagraph_mode": "FULL_DECODE_ONLY",
"compilation_level": CompilationLevel.NO_COMPILATION,
},
# Test case 3: Piecewise for all
{
"case_id": 3,
"cudagraph_mode": "PIECEWISE",
"compilation_level": CompilationLevel.PIECEWISE,
},
])
def test_dispatcher(self, params):
# Setup dispatcher
comp_config = CompilationConfig(
cudagraph_mode=params["cudagraph_mode"],
level=params["compilation_level"],
cudagraph_capture_sizes=[1, 8])
config = _create_vllm_config(comp_config, max_num_seqs=8)
dispatcher = CudagraphDispatcher(config)
dispatcher.initialize_cudagraph_keys(
cudagraph_mode=comp_config.cudagraph_mode,
uniform_decode_query_len=1)
# Verify the key is initialized correctly
if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
# Test dispatch logic
# 1. non-uniform batch, size in cudagraph size list
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
rt_mode, key = dispatcher.dispatch(desc_full_exact)
if params["cudagraph_mode"] == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_full_exact
elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact
else:
assert rt_mode == CUDAGraphMode.NONE
# 2. uniform decode batch, size in cudagraph size list
desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
if params["cudagraph_mode"] == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact.non_uniform
elif params["cudagraph_mode"] in [
"FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"
]:
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact
elif params["cudagraph_mode"] == "PIECEWISE":
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_uniform_exact.non_uniform
else:
assert rt_mode == CUDAGraphMode.NONE
# 3. No key match
desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False)
rt_mode, key = dispatcher.dispatch(desc_no_match)
assert rt_mode == CUDAGraphMode.NONE
assert key is None
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
def setup_method(self):
self.vllm_config = _create_vllm_config(CompilationConfig())
self.model = SimpleMLP().to("cuda")
self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
self.input_tensor = torch.randn(1, 10, device="cuda")
@create_new_process_for_each_test("spawn")
def test_capture_and_replay(self):
wrapper = CUDAGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
batch_descriptor = BatchDescriptor(num_tokens=10)
# 0. global warmup
with set_forward_context(attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=None):
wrapper(self.input_tensor)
# 1. Capture
with set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
batch_descriptor=batch_descriptor),\
patch("torch.cuda.graph",
wraps=torch.cuda.graph) as mock_cuda_graph:
output1 = wrapper(self.input_tensor)
# capturing phase should generate a zero output
assert torch.allclose(output1, torch.zeros_like(output1))
mock_cuda_graph.assert_called_once()
assert batch_descriptor in wrapper.concrete_cudagraph_entries
entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
assert entry.cudagraph is not None
# 2. Replay
with set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
batch_descriptor=batch_descriptor),\
patch.object(entry.cudagraph, 'replay',
wraps=entry.cudagraph.replay) as mock_replay:
output2 = wrapper(self.input_tensor)
mock_replay.assert_called_once()
# Compare with eager output
eager_output = self.model(self.input_tensor)
torch.testing.assert_close(eager_output, output2)
@create_new_process_for_each_test("spawn")
def test_bypass_on_mode_mismatch(self):
wrapper = CUDAGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
batch_descriptor = BatchDescriptor(num_tokens=10)
with set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=batch_descriptor), \
patch('torch.cuda.graph',
wraps=torch.cuda.graph) as mock_cuda_graph, \
patch.object(self.model, 'forward',
wraps=self.model.forward) as mock_forward:
wrapper(self.input_tensor)
mock_cuda_graph.assert_not_called()
mock_forward.assert_called_once()
assert not wrapper.concrete_cudagraph_entries
@create_new_process_for_each_test("spawn")
def test_bypass_on_mode_none(self):
wrapper = CUDAGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
batch_descriptor = BatchDescriptor(num_tokens=10)
with set_forward_context(
attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=batch_descriptor), \
patch('torch.cuda.graph',
wraps=torch.cuda.graph) as mock_cuda_graph:
wrapper(self.input_tensor)
mock_cuda_graph.assert_not_called()
assert not wrapper.concrete_cudagraph_entries
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCudagraphIntegration:
def setup_method(self):
# only FULL mode for non-uniform batches
self.comp_config = CompilationConfig(level=CompilationLevel.PIECEWISE,
cudagraph_mode="FULL",
cudagraph_capture_sizes=[10, 20])
self.vllm_config = _create_vllm_config(self.comp_config)
self.dispatcher = CudagraphDispatcher(self.vllm_config)
self.dispatcher.initialize_cudagraph_keys(
self.comp_config.cudagraph_mode, uniform_decode_query_len=1)
def _run_and_monitor_call(self, wrapper, input_tensor, runtime_mode,
batch_descriptor):
"""Helper to run a single call and monitor the action."""
with patch('torch.cuda.graph',
wraps=torch.cuda.graph) as mock_graph_context, \
patch.object(wrapper, 'runnable',
wraps=wrapper.runnable) as mock_runnable:
entry = wrapper.concrete_cudagraph_entries.get(
batch_descriptor, None)
context = set_forward_context(attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=runtime_mode,
batch_descriptor=batch_descriptor)
mock_replay = MagicMock()
if entry and entry.cudagraph:
with context, \
patch.object(entry.cudagraph, 'replay',
new_callable=MagicMock) as mock_replay:
wrapper(input_tensor)
else:
with context:
wrapper(input_tensor)
if mock_graph_context.called:
# note that this is globally mocked, so it will be detected
# even whether called by the inner or outer wrapper
return "capture_global"
if mock_replay.called:
# only for outer wrapper
return "replay"
if mock_runnable.call_count > 0:
# only for outer wrapper
return "bypass"
return "unknown"
@create_new_process_for_each_test("spawn")
def test_capture_replay_bypass_logic(self):
model = SimpleMLP().to("cuda")
full_wrapper = CUDAGraphWrapper(model, self.vllm_config,
CUDAGraphMode.FULL)
max_bs = 16
persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
input_1 = persistent_input_buffer[:1]
input_2 = persistent_input_buffer[:2]
input_3 = persistent_input_buffer[:3]
desc_1 = BatchDescriptor(num_tokens=1)
desc_2 = BatchDescriptor(num_tokens=2)
desc_3_unseen = BatchDescriptor(num_tokens=3)
# 0. global warmup
with set_forward_context(attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=None):
full_wrapper(input_1)
rt_mode, key = self.dispatcher.dispatch(desc_1)
# 1. Capture first shape
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode,
key)
assert action == "capture_global"
# 2. Replay first shape
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode,
key)
assert action == "replay"
rt_mode, key = self.dispatcher.dispatch(desc_2)
# 3. Capture second shape
action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode,
key)
assert action == "capture_global"
# 4. Replay second shape
action = self._run_and_monitor_call(full_wrapper, input_2,
CUDAGraphMode.FULL, desc_2)
assert action == "replay"
# 5. Bypass if no key match
rt_mode, key = self.dispatcher.dispatch(desc_3_unseen)
assert rt_mode == CUDAGraphMode.NONE
action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode,
key)
assert action == "bypass"
# capture unseen shape is not allowed after disable
set_cudagraph_capturing_enabled(False)
with pytest.raises(RuntimeError):
self._run_and_monitor_call(full_wrapper, input_3,
CUDAGraphMode.FULL, desc_3_unseen)
set_cudagraph_capturing_enabled(True)
@create_new_process_for_each_test("spawn")
def test_nested_wrappers(self):
"""Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
model = SimpleMLP().to("cuda")
full_wrapper = CUDAGraphWrapper(model, self.vllm_config,
CUDAGraphMode.FULL)
input_1 = torch.randn(1, 10, device="cuda")
# Setup: Inner model is wrapped with PIECEWISE, outer with FULL
inner_model = SimpleMLP().to("cuda")
piecewise_wrapper = CUDAGraphWrapper(inner_model, self.vllm_config,
CUDAGraphMode.PIECEWISE)
inner_model.forward = MagicMock(wraps=inner_model.forward)
outer_model = SimpleMLP().to("cuda")
# When outer model is called, it calls the piecewise_wrapper
outer_model.forward = MagicMock(wraps=outer_model.forward,
side_effect=piecewise_wrapper)
full_wrapper = CUDAGraphWrapper(outer_model, self.vllm_config,
CUDAGraphMode.FULL)
desc_1 = BatchDescriptor(num_tokens=1)
# 0. global warmup
with set_forward_context(attn_metadata=None,
vllm_config=self.vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
batch_descriptor=None):
full_wrapper(input_1)
# --- Test runtime mode FULL---
# Run with FULL mode context. Expect outer wrapper to capture.
# The inner mock should be called once inside the graph capture.
outer_model.forward.reset_mock()
inner_model.forward.reset_mock()
action = self._run_and_monitor_call(full_wrapper, input_1,
CUDAGraphMode.FULL, desc_1)
assert action == "capture_global"
assert outer_model.forward.call_count == 1
assert inner_model.forward.call_count == 1
# Run again. Expect outer wrapper to replay.
# The outer model should NOT be called because the whole graph
# is replayed.
action = self._run_and_monitor_call(full_wrapper, input_1,
CUDAGraphMode.FULL, desc_1)
assert action == "replay"
assert outer_model.forward.call_count == 1 # No new call
assert inner_model.forward.call_count == 1
# --- Test runtime mode PIECEWISE ---
outer_model.forward.reset_mock()
inner_model.forward.reset_mock()
# Run with PIECEWISE mode context.
# Expect outer wrapper to bypass and call inner wrapper.
# Inner wrapper should capture.
action = self._run_and_monitor_call(full_wrapper, input_1,
CUDAGraphMode.PIECEWISE, desc_1)
assert action == "capture_global"
assert outer_model.forward.call_count == 1
assert inner_model.forward.call_count == 1
# Run again with PIECEWISE.
# Outer bypasses, inner replays.
action = self._run_and_monitor_call(full_wrapper, input_1,
CUDAGraphMode.PIECEWISE, desc_1)
assert action == "bypass"
assert outer_model.forward.call_count == 2
assert inner_model.forward.call_count == 1

View File

@ -0,0 +1,187 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import weakref
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Optional
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from vllm import LLM
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
@contextlib.contextmanager
def temporary_environ(env_vars):
"""
Temporarily set environment variables and restore them afterward.
We have to do this vs monkeypatch because monkeypatch doesn't work
with "module" scoped fixtures.
"""
original_env = {k: os.environ.get(k) for k in env_vars}
try:
os.environ.update(env_vars)
yield
finally:
for k, v in original_env.items():
if v is None:
os.environ.pop(k, None)
else:
os.environ[k] = v
@dataclass
class BackendConfig:
name: str
env_vars: dict
comp_config: dict
specific_gpu_arch: Optional[tuple] = None
# Define all backend configurations of full cudagraph to be tested
backend_configs = {
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
comp_config={
"cudagraph_mode": "FULL",
},
specific_gpu_arch=(9, 0)),
# FlashMLA on Hopper
"FlashMLA":
BackendConfig(name="FlashMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
# Triton Attention
"TritonAttn":
BackendConfig(name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
# FlashInfer
"FlashInfer":
BackendConfig(name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
}
# test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported)
combo_cases_1 = [
("FA3", "FULL", True),
("FA3", "FULL_AND_PIECEWISE", True),
("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
("FA2", "FULL_AND_PIECEWISE", True),
("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
("FlashInfer", "FULL_AND_PIECEWISE", True),
]
@pytest.mark.parametrize("combo_case", combo_cases_1)
def test_backend_and_cudagraph_mode_combo(combo_case):
backend_name, cudagraph_mode, supported = combo_case
if backend_name == "FlashInfer":
try:
import flashinfer # noqa: F401
except ImportError:
pytest.skip("FlashInfer is not installed")
backend_config = backend_configs[backend_name]
# Dynamically skip test if GPU capability is not met
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
!= current_platform.get_device_capability():
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars}
with temporary_environ(env_vars), ExitStack() as stack:
if not supported:
stack.enter_context(pytest.raises(Exception))
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
max_num_seqs=256,
trust_remote_code=True,
gpu_memory_utilization=0.45,
max_model_len=1024,
compilation_config=CompilationConfig(
level=3, cudagraph_mode=cudagraph_mode))
llm.generate(["Hello, my name is"] * 10)
try:
llm = weakref.proxy(llm)
del llm
except UnboundLocalError:
pass
wait_for_gpu_memory_to_clear(
devices=[0],
threshold_ratio=0.1,
)
# test cudagraph_mode with different compilation level.
# (backend_name, cudagraph_mode, compilation_level, supported)
combo_cases_2 = [
("FA2", "FULL", 0, True), # no compilation + full cudagraph
("FA2", "FULL", 3, True), # piecewise compilation + full cudagraph
("FA2", "PIECEWISE", 0, False), # no compilation + piecewise cudagraph
("FA2", "PIECEWISE", 3,
True), # piecewise compilation + piecewise cudagraph
("FA2", "FULL_AND_PIECEWISE", 0,
False), # piecewise cudagraph not supported without piecewise compilation
("FA2", "FULL_AND_PIECEWISE", 3, True),
("FA2", "FULL_DECODE_ONLY", 0, True),
("FA2", "FULL_DECODE_ONLY", 3, True),
("FA2", "NONE", 0, True), # no compilation + no cudagraph
("FA2", "NONE", 3, True), # piecewise compilation + no cudagraph
]
@pytest.mark.parametrize("combo_case", combo_cases_2)
def test_cudagraph_compilation_combo(combo_case):
backend_name, cudagraph_mode, compilation_level, supported\
= combo_case
env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars}
with temporary_environ(env_vars), ExitStack() as stack:
if not supported:
stack.enter_context(pytest.raises(Exception))
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
max_num_seqs=256,
trust_remote_code=True,
gpu_memory_utilization=0.45,
max_model_len=1024,
compilation_config=CompilationConfig(
level=compilation_level, cudagraph_mode=cudagraph_mode))
llm.generate(["Hello, my name is"] * 10)
try:
llm = weakref.proxy(llm)
del llm
except UnboundLocalError:
pass
finally:
wait_for_gpu_memory_to_clear(
devices=[0],
threshold_ratio=0.1,
)

View File

@ -162,6 +162,12 @@ def test_eagle_correctness(
mm_enabled: bool,
attn_backend: str,
):
if attn_backend == "TREE_ATTN":
# TODO: Fix this flaky test
pytest.skip(
"TREE_ATTN is flaky in the test disable for now until it can be "
"reolved (see https://github.com/vllm-project/vllm/issues/22922)")
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
'''

View File

@ -13,6 +13,7 @@ from vllm.assets.image import ImageAsset
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.utils import set_default_torch_num_threads
@ -398,3 +399,89 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch):
# Test 3: Verify healthy engine still works after mock
await engine.check_health()
@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_abort_final_output(
monkeypatch: pytest.MonkeyPatch,
output_kind: RequestOutputKind,
):
"""Test that abort() returns a final output with correct information."""
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
request_id = "test-abort-final-output"
# Start a long-running request
sampling_params = SamplingParams(
max_tokens=3000, # Long enough to allow abort
ignore_eos=True,
output_kind=output_kind,
temperature=0.5,
seed=42,
)
outputs: list[RequestOutput] = []
generated = asyncio.create_task(
collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params,
outputs))
# Let it generate some tokens
await asyncio.sleep(0.5)
# Abort the request
await engine.abort(request_id)
# Wait for generation to complete and return final output
final_output = await generated
# Verify we got a final output
assert final_output is not None
assert final_output.finished
assert len(final_output.outputs) == 1
assert final_output.outputs[0].finish_reason == "abort"
assert final_output.outputs[0].stop_reason is None
# Verify num_cached_tokens is set correctly
assert hasattr(final_output, 'num_cached_tokens')
assert final_output.num_cached_tokens >= 0
# If we got intermediate outputs, verify they are consistent
if output_kind == RequestOutputKind.DELTA:
# For DELTA, sum all intermediate tokens should <= final tokens
token_count = sum(
len(output.outputs[0].token_ids) for output in outputs)
assert token_count > 0
assert len(final_output.outputs[0].token_ids) == 0
else:
# For FINAL_ONLY, we should only get the final output
assert len(outputs) == 0
assert len(final_output.outputs[0].token_ids) > 0
assert not engine.output_processor.has_unfinished_requests()
async def collect_outputs(
engine: AsyncLLM,
request_id: str,
prompt: PromptType,
sampling_params: SamplingParams,
outputs_list: list[RequestOutput],
) -> Optional[RequestOutput]:
"""Helper to collect outputs and return the final one."""
final_output: Optional[RequestOutput] = None
async for output in engine.generate(request_id=request_id,
prompt=prompt,
sampling_params=sampling_params):
if not output.finished:
outputs_list.append(output)
final_output = output
return final_output

View File

@ -4,6 +4,8 @@ import asyncio
import os
import threading
import time
import traceback
from typing import Optional, cast
import openai # use the official client for correctness check
import pytest
@ -41,12 +43,15 @@ class MultinodeInternalLBServerManager:
self.tp_size = tp_size
self.api_server_count = api_server_count
self.base_server_args = base_server_args
self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
self.servers: list[Optional[tuple[RemoteOpenAIServer,
list[str]]]] = [None] * (dp_size //
dp_per_node)
self.server_threads: list[threading.Thread] = []
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
"""Start all server instances for multi-node internal LB mode."""
for rank in range(0, self.dp_size, self.dp_per_node):
for server_idx, rank in enumerate(
range(0, self.dp_size, self.dp_per_node)):
# Create server args for this specific rank
server_args = self.base_server_args.copy()
@ -87,7 +92,7 @@ class MultinodeInternalLBServerManager:
])
# Use a thread to start each server to allow parallel initialization
def start_server(r: int, sargs: list[str]):
def start_server(sidx: int, r: int, sargs: list[str]):
gpus_per_node = self.tp_size * self.dp_per_node
try:
# Start the server
@ -110,13 +115,14 @@ class MultinodeInternalLBServerManager:
f"{self.api_server_count} API servers")
else:
print(f"Headless node (rank {r}) started successfully")
self.servers.append((server, sargs))
self.servers[sidx] = (server, sargs)
except Exception as e:
print(f"Failed to start server rank {r}: {e}")
traceback.print_exc()
raise
thread = threading.Thread(target=start_server,
args=(rank, server_args))
args=(server_idx, rank, server_args))
thread.start()
self.server_threads.append(thread)
@ -128,18 +134,20 @@ class MultinodeInternalLBServerManager:
# Give servers additional time to fully initialize and coordinate
time.sleep(3)
if len(self.servers) != self.dp_size // self.dp_per_node:
if not all(self.servers):
raise Exception("Servers failed to start")
return self.servers
return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers)
def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop all server instances."""
while self.servers:
try:
self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
except Exception as e:
print(f"Error stopping server: {e}")
if server := self.servers.pop():
try:
server[0].__exit__(exc_type, exc_val, exc_tb)
except Exception as e:
print(f"Error stopping server: {e}")
traceback.print_exc()
class APIOnlyServerManager:
@ -157,7 +165,8 @@ class APIOnlyServerManager:
self.tp_size = tp_size
self.api_server_count = api_server_count
self.base_server_args = base_server_args
self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
self.servers: list[Optional[tuple[RemoteOpenAIServer,
list[str]]]] = [None] * 2
self.server_threads: list[threading.Thread] = []
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
@ -209,7 +218,7 @@ class APIOnlyServerManager:
server.__enter__()
print(f"API-only server started successfully with "
f"{self.api_server_count} API servers")
self.servers.append((server, api_server_args))
self.servers[0] = (server, api_server_args)
except Exception as e:
print(f"Failed to start API-only server: {e}")
raise
@ -231,7 +240,7 @@ class APIOnlyServerManager:
server.__enter__()
print(f"Headless engines server started successfully with "
f"{self.dp_size} engines")
self.servers.append((server, engines_server_args))
self.servers[1] = (server, engines_server_args)
except Exception as e:
print(f"Failed to start headless engines server: {e}")
raise
@ -253,18 +262,20 @@ class APIOnlyServerManager:
# Give servers additional time to fully initialize and coordinate
time.sleep(3)
if len(self.servers) != 2:
if not all(self.servers):
raise Exception("Both servers failed to start")
return self.servers
return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers)
def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop both server instances."""
while self.servers:
try:
self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
except Exception as e:
print(f"Error stopping server: {e}")
if server := self.servers.pop():
try:
server[0].__exit__(exc_type, exc_val, exc_tb)
except Exception as e:
print(f"Error stopping server: {e}")
traceback.print_exc()
@pytest.fixture(scope="module")
@ -560,7 +571,7 @@ async def test_api_only_multinode_dp_completion(
assert len(results) == num_requests
assert all(completion is not None for completion in results)
_, api_server_args = api_only_servers[0]
api_server, api_server_args = api_only_servers[0]
api_server_count = (
api_server_args.count('--api-server-count')
and api_server_args[api_server_args.index('--api-server-count') + 1]
@ -569,7 +580,6 @@ async def test_api_only_multinode_dp_completion(
f"engines on headless server (API server count: {api_server_count})")
# Check request balancing via Prometheus metrics
api_server = api_only_servers[0][0]
check_request_balancing(api_server, DP_SIZE)

View File

@ -772,6 +772,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
head_dim=hf_config.mamba_d_head,
rms_norm_eps=hf_config.rms_norm_eps,
activation=hf_config.hidden_act,
cache_config=cache_config,
model_config=model_config,
prefix=key,
)
# suppress var not used error

View File

@ -319,38 +319,6 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
repetition_penalties)
def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor,
seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
block_tables: torch.Tensor) -> None:
"""Advance a step on GPU for existing inputs for a multi-step runner"""
return torch.ops._C.advance_step_flashattn(num_seqs, num_queries,
block_size, input_tokens,
sampled_token_ids,
input_positions, seq_lens,
slot_mapping, block_tables)
def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor,
seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
block_tables: torch.Tensor,
paged_kv_indices: torch.Tensor,
paged_kv_indptr: torch.Tensor,
paged_kv_last_page_len: torch.Tensor,
block_table_bound: torch.Tensor) -> None:
return torch.ops._C.advance_step_flashinfer(
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping, block_tables,
paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len,
block_table_bound)
# fused quant layer norm ops
def rms_norm_dynamic_per_token_quant(
input: torch.Tensor,
@ -452,6 +420,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def _gptq_marlin_gemm_fake(a: torch.Tensor,
c: Optional[torch.Tensor],
b_q_weight: torch.Tensor,
b_bias: Optional[torch.Tensor],
b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor],
b_zeros: Optional[torch.Tensor],
@ -1048,6 +1017,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
def gptq_marlin_gemm(a: torch.Tensor,
c: Optional[torch.Tensor],
b_q_weight: torch.Tensor,
b_bias: Optional[torch.Tensor],
b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor],
b_zeros: Optional[torch.Tensor],
@ -1062,7 +1032,7 @@ def gptq_marlin_gemm(a: torch.Tensor,
use_atomic_add: bool = False,
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales,
return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_bias, b_scales,
global_scale, b_zeros, g_idx, perm,
workspace, b_q_type.id, size_m,
size_n, size_k, is_k_full,
@ -1540,7 +1510,9 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
b_qweight: torch.Tensor, b_scales: torch.Tensor,
b_qweight: torch.Tensor,
b_bias: Optional[torch.Tensor],
b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor],
b_qzeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
@ -1556,11 +1528,11 @@ def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
use_fp32_reduce: bool,
is_zp_float: bool) -> torch.Tensor:
return torch.ops._moe_C.moe_wna16_marlin_gemm(
input, output, b_qweight, b_scales, global_scale, b_qzeros, g_idx,
perm, workspace, sorted_token_ids, expert_ids, num_tokens_past_padded,
topk_weights, moe_block_size, top_k, mul_topk_weights, is_ep,
b_q_type.id, size_m, size_n, size_k, is_k_full, use_atomic_add,
use_fp32_reduce, is_zp_float)
input, output, b_qweight, b_bias, b_scales, global_scale, b_qzeros,
g_idx, perm, workspace, sorted_token_ids, expert_ids,
num_tokens_past_padded, topk_weights, moe_block_size, top_k,
mul_topk_weights, is_ep, b_q_type.id, size_m, size_n, size_k,
is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float)
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):

View File

@ -101,11 +101,6 @@ class AttentionBackend(ABC):
) -> None:
raise NotImplementedError
def advance_step(self, model_input: "ModelRunnerInputBase",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int) -> None:
raise NotImplementedError
@classmethod
def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__)

View File

@ -35,8 +35,7 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner import ModelInputForGPUBuilder
logger = init_logger(__name__)
@ -326,79 +325,6 @@ class DifferentialFlashAttentionMetadata(AttentionMetadata):
cross_block_tables=self.cross_block_tables)
return self._cached_decode_metadata
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph
if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )
assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )
assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)
ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)
class DifferentialFlashAttentionMetadataBuilder(
AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]):

View File

@ -32,8 +32,7 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner import ModelInputForGPUBuilder
logger = init_logger(__name__)
@ -309,79 +308,6 @@ class FlashAttentionMetadata(AttentionMetadata):
cross_block_tables=self.cross_block_tables)
return self._cached_decode_metadata
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph
if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )
assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )
assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)
ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)
class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):

View File

@ -51,8 +51,7 @@ from vllm.utils.flashinfer import use_trtllm_attention
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner import ModelInputForGPUBuilder
class FlashInferBackend(AttentionBackend):
@ -428,7 +427,7 @@ class FlashInferMetadata(AttentionMetadata):
query_start_loc: Optional[torch.Tensor] = None
block_tables: Optional[torch.Tensor] = None
# used for GPU in-place advance_step
# used for GPU operations
seq_lens_tensor: Optional[torch.Tensor] = None
block_table_bound: Optional[torch.Tensor] = None
@ -587,66 +586,6 @@ class FlashInferMetadata(AttentionMetadata):
return None
return self
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
# Flashinfer doesn't support speculative decoding + chunked-prefill
# + multi-step scheduling yet.
assert self.decode_query_len == 1
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens_tensor is not None
assert num_seqs > 0
assert num_queries > 0
assert model_input.attn_metadata is not None
assert sampled_token_ids is not None
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph
model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()
# Update GPU tensors
ops.advance_step_flashinfer(
num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=model_input.input_tokens,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables,
paged_kv_indices=self.paged_kv_indices,
paged_kv_indptr=self.paged_kv_indptr,
paged_kv_last_page_len=self.paged_kv_last_page_len,
block_table_bound=self.block_table_bound)
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

View File

@ -3,7 +3,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type
import torch
@ -18,9 +18,6 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
class FlashMLABackend(MLACommonBackend):
@ -62,16 +59,6 @@ class FlashMLAMetadata(MLACommonMetadata):
self.decode_num_splits
return decode_metadata
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
raise NotImplementedError(
"advance_step is not implemented for FlashMLA")
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):

View File

@ -234,8 +234,7 @@ except ImportError:
flash_attn_varlen_func = None
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner import ModelInputForGPUBuilder
is_hip = current_platform.is_rocm()
@ -631,90 +630,6 @@ class MLACommonMetadata(AttentionMetadata):
is_profile_run=self.is_profile_run)
return self._cached_decode_metadata
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )
assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )
assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)
self._ops_advance_step(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions)
def _ops_advance_step(self, num_seqs: int, num_queries: int,
block_size: int, input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor) -> None:
# here we use advance_step_flashinfo to update the paged_kv_* tensors
ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)
class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
"""

View File

@ -15,8 +15,7 @@ from vllm.attention.backends.utils import CommonAttentionState
from vllm.multimodal import MultiModalPlaceholderMap
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner import (ModelInputForGPUBuilder)
from vllm.utils import async_tensor_h2d
# Placeholder attention backend for models like Mamba and pooling models that
@ -201,65 +200,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
)
return self._cached_decode_metadata
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph
assert not turn_prefills_into_decodes, \
("Multi-Step + Chunked-Prefill is not supported for attention-free"
"models. turn_prefills_into_decodes is a "
"Multi-Step + Chunked-Prefill specific parameter.")
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )
assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)
# Update sequences, masking off entries greater than num_queries
device = self.seq_lens_tensor.device
mask = torch.arange(self.seq_lens_tensor.size(0),
device=device) < num_queries
self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype)
if sampled_token_ids is not None:
model_input.input_tokens.masked_scatter_(
mask, sampled_token_ids[:num_queries])
class PlaceholderAttentionMetadataBuilder(
AttentionMetadataBuilder[PlaceholderAttentionMetadata]):

View File

@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Optional, Type, Union
import torch
import vllm._custom_ops as ops
import vllm.envs as envs
from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
@ -107,26 +106,6 @@ class AiterMLAMetadata(MLACommonMetadata):
return self._cached_decode_metadata
def _ops_advance_step(self, num_seqs: int, num_queries: int,
block_size: int, input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor) -> None:
ops.advance_step_flashinfer(
num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables,
paged_kv_indices=self.paged_kv_indices,
paged_kv_indptr=self.paged_kv_indptr,
paged_kv_last_page_lens=self.paged_kv_last_page_lens,
block_table_bound=self.block_table_bound)
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]

View File

@ -4,7 +4,7 @@
import itertools
from dataclasses import dataclass
from functools import cache
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type
import torch
@ -23,9 +23,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
logger = init_logger(__name__)
_PARTITION_SIZE_ROCM = 256
@ -261,69 +258,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
self._cached_decode_metadata.query_start_loc = qs - qs[0]
return self._cached_decode_metadata
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
assert not turn_prefills_into_decodes, \
("Chunked prefill is not supported with rocm_flash_attn yet."
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
"specific parameter.")
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )
assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )
assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)
ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)
class ROCmFlashAttentionMetadataBuilder(
CommonMetadataBuilder[ROCmFlashAttentionMetadata]):

View File

@ -15,7 +15,7 @@ import torch.fx as fx
from torch._dispatch.python import enable_python_dispatcher
import vllm.envs as envs
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
@ -277,9 +277,6 @@ def split_graph(graph: fx.GraphModule,
return split_gm, outputs
# we share the global graph pool among all the backends
global_graph_pool = None
compilation_start_time = 0.0
@ -339,14 +336,37 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
graph_index=index,
num_graphs=len(self.compile_submod_names),
runtime_shape=None)
# Lazy import here to avoid circular import
from .cuda_graph import CUDAGraphOptions
from .cuda_piecewise_backend import PiecewiseBackend
piecewise_backend = resolve_obj_by_qualname(
current_platform.get_piecewise_backend_cls())
self.module.__dict__[target] = piecewise_backend(
submod, self.vllm_config, self.graph_pool, index,
piecewise_backend = PiecewiseBackend(
submod, self.vllm_config, index,
len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_dynamic_shape, self.vllm_backend)
if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
# class) as platform dependent.
static_graph_wrapper_class = resolve_obj_by_qualname(
current_platform.get_static_graph_wrapper_cls())
# Always assign PIECEWISE runtime mode to the
# CUDAGraphWrapper for piecewise_backend, to distinguish
# it from the FULL cudagraph runtime mode, no matter it
# is wrapped on a full or piecewise fx graph.
self.module.__dict__[target] = static_graph_wrapper_class(
runnable=piecewise_backend,
vllm_config=self.vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
graph_pool=self.graph_pool,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=piecewise_backend.is_first_graph,
gc_disable=not piecewise_backend.is_first_graph,
weak_ref_output=piecewise_backend.is_last_graph))
else:
self.module.__dict__[target] = piecewise_backend
compilation_counter.num_piecewise_capturable_graphs_seen += 1
return output
@ -413,9 +433,7 @@ class VllmBackend:
# them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag
global global_graph_pool
if global_graph_pool is None:
global_graph_pool = current_platform.graph_pool_handle()
global_graph_pool = current_platform.get_global_graph_pool()
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
@ -585,7 +603,7 @@ class VllmBackend:
self._called = True
if not self.compilation_config.use_cudagraph or \
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \
not self.compilation_config.cudagraph_copy_inputs:
return self.split_gm

View File

@ -1,72 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Protocol
import torch.fx as fx
from vllm.compilation.backends import VllmBackend
from vllm.config import VllmConfig
class AbstractPiecewiseBackend(Protocol):
"""
PiecewiseBackend interface that allows platforms to extend
piecewise static graph.
"""
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, piecewise_compile_index: int,
total_piecewise_compiles: int, sym_shape_indices: list[int],
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend, **kwargs):
"""
Initializes the PiecewiseBackend class with compilation and
execution-related configurations.
This class handles piecewise compilation, graph capturing,
and dispatching for specific input shapes.
Args:
graph (fx.GraphModule): The graph represented in fx.
vllm_config (VllmConfig): Global configuration for vLLM.
graph_pool (Any):
Graph memory pool handle, e.g.,
`torch.cuda.graph_pool_handle()`.
piecewise_compile_index (int):
Index of the current piecewise subgraph.
total_piecewise_compiles (int):
Total number of piecewise-compiled graphs.
sym_shape_indices (list[int]):
Indices of symbolic shape.
compiled_graph_for_general_shape (Callable):
Callable that executes the graph compiled for general shapes.
vllm_backend (VllmBackend):
Backend compiler that manages compilation and graph runtime
for vLLM.
Keyword Args:
kwargs: Additional keyword arguments reserved for future
extensions or custom platforms.
"""
raise NotImplementedError
def __call__(self, *args) -> Any:
"""Executes the compiled graph for given input args.
If this is the first invocation, executes the general compiled graph
and initiates the compilation process tracking. For subsequent calls,
dynamically dispatches execution to either a compiled graph or a static
graph based on the input shape.
Args:
*args: Variable length input arguments to be passed into the
graph. The symbolic shape is expected to be in position
`sym_shape_indices[0]`.
Returns:
Any: Output of the executed graph. This can be from the general
compiled graph, a specialized compiled version for the given shape,
or a replayed static graph.
"""
raise NotImplementedError

View File

@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Protocol
from vllm.config import CUDAGraphMode, VllmConfig
class AbstractStaticGraphWrapper(Protocol):
"""
StaticGraphWrapper interface that allows platforms to wrap a callable
to be captured as a static graph.
"""
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs):
"""
Initializes the StaticGraphWrapper class with graph capturing and
execution-related configurations.
Args:
runnable (Callable): The callable to be wrapped and captured.
vllm_config (VllmConfig): Global configuration for vLLM.
runtime_mode (CUDAGraphMode): The style of the static
graph runtime. See CUDAGraphMode in vllm/config.py.
Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
are used as concrete runtime mode for cudagraph dispatching.
graph_pool (Any):
Graph memory pool handle, e.g.,
`torch.cuda.graph_pool_handle()`.
Keyword Args:
kwargs: Additional keyword arguments for platform-specific
configurations.
"""
raise NotImplementedError
def __call__(self, *args, **kwargs) -> Any:
"""
Executes the wrapped callable.
If the current runtime mode in the ForwardContext matches the runtime
mode of this instance, it replays the CUDAGraph or captures it using
the callable if it hasn't been captured yet. Otherwise, it calls the
original callable directly.
Args:
*args: Variable length input arguments to be passed into the
callable.
**kwargs: Keyword arguments to be passed into the callable.
Returns:
Any: Output of the executed callable.
"""
raise NotImplementedError

View File

@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Optional
from unittest.mock import patch
import torch
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import weak_ref_tensors
logger = init_logger(__name__)
@dataclasses.dataclass
class CUDAGraphEntry:
batch_descriptor: BatchDescriptor
cudagraph: Optional[torch.cuda.CUDAGraph] = None
output: Optional[Any] = None
# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None
@dataclasses.dataclass
class CUDAGraphOptions:
debug_log_enable: bool = True
gc_disable: bool = False
weak_ref_output: bool = True
class CUDAGraphWrapper:
"""Wraps a runnable to add CUDA graph capturing and replaying ability. And
provide attribute access to the underlying `runnable` via `__getattr__`.
The workflow of this wrapper in the cudagraph dispatching is as follows:
1. At initialization, a runtime mode is assigned to the wrapper (FULL or
PIECEWISE).
2. At runtime, the wrapper receives a runtime_mode and a
batch_descriptor(key) from the forward context and blindly trust them
for cudagraph dispatching.
3. If runtime_mode is NONE or runtime_mode does not match the mode of the
wrapper, just call the runnable directly.
4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
the wrapper will perform cudagraph capture(if key does not exist, create
a new entry and cache it) or replay (if key exists in the cache).
Note: CUDAGraphWrapper does not store persistent buffers or copy any
runtime inputs into that buffers for replay. We assume implementing them
is done outside of the wrapper. That is because we do not make any
assumption on the dynamic shape (batch size) of the runtime inputs, as a
trade-off for staying orthogonal to compilation logic. Nevertheless,
tracing and checking the input addresses to be consistent during replay is
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
"""
def __init__(self,
runnable: Callable,
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
graph_pool: Any = None,
cudagraph_options: Optional[CUDAGraphOptions] = None):
self.runnable = runnable
self.vllm_config = vllm_config
self.graph_pool = graph_pool
self.runtime_mode = runtime_mode
self.compilation_config = vllm_config.compilation_config
self.first_run_finished = False
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
# need to initialize a CUDAGraphWrapper.
assert self.runtime_mode != CUDAGraphMode.NONE
if self.graph_pool is None:
self.graph_pool = current_platform.get_global_graph_pool()
if cudagraph_options is None:
cudagraph_options = CUDAGraphOptions()
self.cudagraph_options = cudagraph_options
# the entries for different batch descriptors that we need to capture
# cudagraphs for.
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry]\
= {}
def __getattr__(self, key: str):
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
return getattr(self.runnable, key)
raise AttributeError(f"Attribute {key} not exists in the runnable of "
f"cudagraph wrapper: {self.runnable}")
def unwrap(self) -> Callable:
# in case we need to access the original runnable.
return self.runnable
def __call__(self, *args, **kwargs):
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
if cudagraph_runtime_mode == CUDAGraphMode.NONE or \
cudagraph_runtime_mode != self.runtime_mode:
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
# running without cudagraphs.
# We do not trigger capture/replay if the runtime mode is not
# matches. This enables properly dispatching to the correct
# CUDAGraphWrapper when nesting multiple instances with different
# runtime modes.
return self.runnable(*args, **kwargs)
if batch_descriptor not in self.concrete_cudagraph_entries:
# create a new entry for this batch descriptor
self.concrete_cudagraph_entries[batch_descriptor] = \
CUDAGraphEntry(batch_descriptor=batch_descriptor)
entry = self.concrete_cudagraph_entries[batch_descriptor]
if entry.cudagraph is None:
if self.cudagraph_options.debug_log_enable:
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every
# shape. E.g. we only log it for the first subgraph in
# piecewise mode.
logger.debug("Capturing a cudagraph on (%s,%s)",
self.runtime_mode.name, entry.batch_descriptor)
# validate that cudagraph capturing is legal at this point.
validate_cudagraph_capturing_enabled()
input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()
with ExitStack() as stack:
if self.cudagraph_options.gc_disable:
# during every model forward for piecewise cudagraph
# mode, we will capture many pieces of cudagraphs
# (roughly one per layer). running gc again and again
# across layers will make the cudagraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(
patch("torch.cuda.empty_cache", lambda: None))
# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = self.runnable(*args, **kwargs)
if self.cudagraph_options.weak_ref_output:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph in piecewise cuadgraph mode, because
# the output of the last graph will not be used by
# any other cuda graph.
output = weak_ref_tensors(output)
# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
entry.cudagraph = cudagraph
compilation_counter.num_cudagraph_captured += 1
# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return output
if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
f"Input addresses for cudagraphs are different "
f"during replay. Expected {entry.input_addresses}, "
f"got {new_input_addresses}")
entry.cudagraph.replay()
return entry.output

View File

@ -2,21 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Optional
from unittest.mock import patch
from typing import Any, Callable
import torch
import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.backends import VllmBackend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors
logger = init_logger(__name__)
@ -24,44 +18,29 @@ logger = init_logger(__name__)
@dataclasses.dataclass
class ConcreteSizeEntry:
runtime_shape: int
need_to_compile: bool # the size is in compile_sizes
use_cudagraph: bool # the size is in cudagraph_capture_sizes
compiled: bool = False
runnable: Callable = None # type: ignore
num_finished_warmup: int = 0
cudagraph: Optional[torch.cuda.CUDAGraph] = None
output: Optional[Any] = None
# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None
class CUDAPiecewiseBackend:
class PiecewiseBackend:
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, piecewise_compile_index: int,
total_piecewise_compiles: int, sym_shape_indices: list[int],
piecewise_compile_index: int, total_piecewise_compiles: int,
sym_shape_indices: list[int],
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend):
"""
The backend for piecewise compilation.
It mainly handles the compilation and cudagraph capturing.
It mainly handles the compilation of static shapes and
dispatching based on runtime shape.
We will compile `self.graph` once for the general shape,
and then compile for different shapes specified in
`compilation_config.compile_sizes`.
Independently, we will capture cudagraph for different shapes.
If a shape needs both compilation and cudagraph, we will
compile it first, and then capture cudagraph.
"""
self.graph = graph
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool
self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend
@ -70,11 +49,10 @@ class CUDAPiecewiseBackend:
self.is_last_graph = (
piecewise_compile_index == total_piecewise_compiles - 1)
self.is_full_graph = total_piecewise_compiles == 1
self.compile_sizes: set[int] = set(
self.compilation_config.compile_sizes)
self.cudagraph_capture_sizes: set[int] = set(
self.compilation_config.cudagraph_capture_sizes
) if self.compilation_config.use_cudagraph else set()
self.first_run_finished = False
@ -84,18 +62,18 @@ class CUDAPiecewiseBackend:
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
# the entries for different shapes that we need to either
# compile or capture cudagraph
# the entries for different shapes that we need to compile
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
# to_be_compiled_sizes tracks the remaining sizes to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
# We only keep compilation management inside this class directly.
for shape in self.compile_sizes:
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape,
need_to_compile=shape in self.compile_sizes,
use_cudagraph=shape in self.cudagraph_capture_sizes,
runnable=self.compiled_graph_for_general_shape,
)
def check_for_ending_compilation(self):
@ -112,16 +90,14 @@ class CUDAPiecewiseBackend:
return self.compiled_graph_for_general_shape(*args)
runtime_shape = args[self.sym_shape_indices[0]]
if runtime_shape not in self.concrete_size_entries:
# we don't need to do anything for this shape
return self.compiled_graph_for_general_shape(*args)
entry = self.concrete_size_entries[runtime_shape]
if entry.runnable is None:
entry.runnable = self.compiled_graph_for_general_shape
if entry.need_to_compile and not entry.compiled:
if not entry.compiled:
entry.compiled = True
self.to_be_compiled_sizes.remove(runtime_shape)
# args are real arguments
@ -138,81 +114,4 @@ class CUDAPiecewiseBackend:
if self.is_last_graph and not self.to_be_compiled_sizes:
self.check_for_ending_compilation()
# Skip CUDA graphs if this entry doesn't use them OR
# if we're supposed to skip them globally
skip_cuda_graphs = get_forward_context().skip_cuda_graphs
if not entry.use_cudagraph or skip_cuda_graphs:
return entry.runnable(*args)
if entry.cudagraph is None:
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
entry.num_finished_warmup += 1
if self.is_first_graph:
logger.debug(
"Warming up %s/%s for shape %s",
entry.num_finished_warmup,
self.compilation_config.cudagraph_num_of_warmups,
runtime_shape)
return entry.runnable(*args)
if self.is_first_graph:
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every shape.
# We only log it in the debug mode.
logger.debug("Capturing a cudagraph for shape %s",
runtime_shape)
input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()
with ExitStack() as stack:
if not self.is_first_graph:
# during every model forward, we will capture
# many pieces of cudagraphs (roughly one per layer).
# running gc again and again across layers will
# make the cudagraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(
patch("torch.cuda.empty_cache", lambda: None))
# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = entry.runnable(*args)
if self.is_last_graph:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph, because the output of the last graph
# will not be used by any other cuda graph.
output = weak_ref_tensors(output)
# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
entry.cudagraph = cudagraph
compilation_counter.num_cudagraph_captured += 1
# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return output
if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
"Input addresses for cudagraphs are different during replay."
f" Expected {entry.input_addresses}, got {new_input_addresses}"
)
entry.cudagraph.replay()
return entry.output
return entry.runnable(*args)

View File

@ -37,3 +37,21 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig):
if context_manager is not None:
context_manager.__exit__(None, None, None)
context_manager = None
cudagraph_capturing_enabled: bool = True
def validate_cudagraph_capturing_enabled():
# used to monitor whether an cudagraph capturing is legal at runtime.
# should be called before any cudagraph capturing.
# if an illegal cudagraph capturing happens, raise an error.
global cudagraph_capturing_enabled
if not cudagraph_capturing_enabled:
raise RuntimeError("CUDA graph capturing detected at an inappropriate "
"time. This operation is currently disabled.")
def set_cudagraph_capturing_enabled(enabled: bool):
global cudagraph_capturing_enabled
cudagraph_capturing_enabled = enabled

View File

@ -11,7 +11,8 @@ from typing import Callable, Optional
import torch
import vllm.envs as envs
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.config import (CompilationLevel, CUDAGraphMode,
get_current_vllm_config)
from vllm.logger import init_logger
logger = init_logger(__name__)
@ -115,8 +116,8 @@ class TorchCompileWrapperWithCustomDispatcher:
except Exception:
pass
if self.vllm_config.compilation_config.use_cudagraph and \
"update" in new_code.co_names:
if self.vllm_config.compilation_config.cudagraph_mode != \
CUDAGraphMode.NONE and "update" in new_code.co_names:
import depyf
src = depyf.decompile(new_code)
msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa

View File

@ -29,10 +29,10 @@ from typing_extensions import Self, assert_never, runtime_checkable
import vllm.envs as envs
from vllm import version
from vllm.config.cache import (BlockSize, CacheConfig, CacheDType,
from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
PrefixCachingHashAlgo)
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
PassConfig)
CUDAGraphMode, PassConfig)
from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
from vllm.config.utils import ConfigType, config
@ -388,6 +388,10 @@ class ModelConfig:
interleave_mm_strings: bool = False
"""Enable fully interleaved support for multimodal prompts, while using
--chat-template-content-format=string. Defaults to False."""
skip_mm_profiling: bool = False
"""When enabled, skips multimodal memory profiling and only profiles with
language backbone model during engine initialization.
"""
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set
@ -837,7 +841,8 @@ class ModelConfig:
media_io_kwargs=self.media_io_kwargs,
mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb,
interleave_mm_strings=self.interleave_mm_strings)
interleave_mm_strings=self.interleave_mm_strings,
skip_mm_profiling=self.skip_mm_profiling)
return None
@ -2511,6 +2516,16 @@ class MultiModalConfig:
Enable fully interleaved support for multimodal prompts.
"""
skip_mm_profiling: bool = False
"""
When enabled, skips multimodal memory profiling and only profiles with
language backbone model during engine initialization.
This reduces engine startup time but shifts the responsibility to users for
estimating the peak memory usage of the activation of multimodal encoder and
embedding cache.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@ -3514,11 +3529,21 @@ class VllmConfig:
else:
self.compilation_config.level = \
CompilationLevel.NO_COMPILATION
else:
# NB: Passing both --enforce-eager and a compilation level
# in V0 means the compilation level wins out.
self.compilation_config.level = CompilationLevel.NO_COMPILATION
# if cudagraph_mode is not explicitly set by users, set default value
if self.compilation_config.cudagraph_mode is None:
if envs.VLLM_USE_V1 and self.compilation_config.level \
== CompilationLevel.PIECEWISE:
self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else:
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# async tp is built on top of sequence parallelism
# and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp:
@ -3526,12 +3551,13 @@ class VllmConfig:
True
if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm")
if envs.VLLM_USE_V1 and self.model_config is not None and \
not self.model_config.enforce_eager:
# By default, V1 uses piecewise CUDA graphs. If full_cuda_graph
# is set to True, full CUDA graphs will be used.
# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:
logger.info("Cudagraph is disabled under eager mode")
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
elif envs.VLLM_USE_V1:
self.compilation_config.cudagraph_num_of_warmups = 1
self.compilation_config.set_splitting_ops_for_v1()
self._set_cudagraph_sizes()
@ -3551,12 +3577,6 @@ class VllmConfig:
"Disabling `torch.compile`.")
self.compilation_config.level = CompilationLevel.NO_COMPILATION
if self.compilation_config.full_cuda_graph and \
not self.model_config.disable_cascade_attn:
logger.info("full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
self.model_config.disable_cascade_attn = True
disable_chunked_prefill_reasons: list[str] = []
if self.model_config and self.model_config.pooler_config:
@ -3597,9 +3617,32 @@ class VllmConfig:
"to True to enable.")
current_platform.check_and_update_config(self)
# final check of cudagraph mode after platform-specific update
if envs.VLLM_USE_V1:
if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \
and self.model_config is not None and \
not self.model_config.disable_cascade_attn:
logger.info("CUDAGraphMode.FULL is not supported with "
"cascade attention currently. Disabling cascade"
"attention.")
self.model_config.disable_cascade_attn = True
if self.compilation_config.cudagraph_mode\
.requires_piecewise_compilation():
assert self.compilation_config.level == \
CompilationLevel.PIECEWISE, \
"Compilation level should be CompilationLevel.PIECEWISE "\
"when cudagraph_mode piecewise cudagraphs is used, "\
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
if not self.instance_id:
self.instance_id = random_uuid()[:5]
# Do this after all the updates to compilation_config.level
if envs.VLLM_USE_V1 and \
self.compilation_config.level == CompilationLevel.PIECEWISE:
self.compilation_config.set_splitting_ops_for_v1()
if (envs.VLLM_USE_V1
and not self.scheduler_config.disable_hybrid_kv_cache_manager):
# logger should only print warning message for hybrid models. As we

View File

@ -23,6 +23,7 @@ logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
@ -93,6 +94,15 @@ class CacheConfig:
""" Optional override for mamba page size; used by hybrid mamba/attention
models to ensure exact alignment with attention page size."""
mamba_cache_dtype: MambaDType = "auto"
"""The data type to use for the Mamba cache (both the conv as well as the
ssm state). If set to 'auto', the data type will be inferred from the model
config."""
mamba_ssm_cache_dtype: MambaDType = "auto"
"""The data type to use for the Mamba cache (ssm state only, conv state will
still be controlled by mamba_cache_dtype). If set to 'auto', the data type
for the ssm state will be determined by mamba_cache_dtype."""
# Will be set after profiling.
num_gpu_blocks: Optional[int] = field(default=None, init=False)
"""The number of blocks to allocate for GPU memory."""
@ -123,6 +133,8 @@ class CacheConfig:
"""
factors: list[Any] = []
factors.append(self.cache_dtype)
factors.append(self.mamba_cache_dtype)
factors.append(self.mamba_ssm_cache_dtype)
# `cpu_offload_gb` does not use `torch.compile` yet.
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()

View File

@ -1,12 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import hashlib
from collections import Counter
from dataclasses import asdict, field
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union
from pydantic import TypeAdapter
from pydantic import TypeAdapter, field_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs
@ -31,6 +32,40 @@ class CompilationLevel:
PIECEWISE = 3
class CUDAGraphMode(enum.Enum):
""" Constants for the cudagraph mode in CompilationConfig.
Meanwhile, the subset enum `NONE`, `PIECEWISE` and `FULL` are also
treated as concrete runtime mode for cudagraph runtime dispatching.
"""
NONE = 0
PIECEWISE = 1
FULL = 2
FULL_DECODE_ONLY = (FULL, NONE)
FULL_AND_PIECEWISE = (FULL, PIECEWISE)
def decode_mode(self) -> 'CUDAGraphMode':
return CUDAGraphMode(self.value[0]) if \
self.separate_routine() else self
def mixed_mode(self) -> 'CUDAGraphMode':
return CUDAGraphMode(self.value[1]) if \
self.separate_routine() else self
def requires_piecewise_compilation(self) -> bool:
return (self.decode_mode() == CUDAGraphMode.PIECEWISE
or self.mixed_mode() == CUDAGraphMode.PIECEWISE)
def max_cudagraph_mode(self) -> 'CUDAGraphMode':
return CUDAGraphMode(max(
self.value)) if self.separate_routine() else self
def has_full_cudagraphs(self) -> bool:
return self.max_cudagraph_mode() == CUDAGraphMode.FULL
def separate_routine(self) -> bool:
return isinstance(self.value, tuple)
@config
@dataclass
class PassConfig:
@ -91,6 +126,7 @@ class CompilationConfig:
- [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
- CudaGraph capture:
- [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
- [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
- [`cudagraph_capture_sizes`]
[vllm.config.CompilationConfig.cudagraph_capture_sizes]
- [`cudagraph_num_of_warmups`]
@ -157,7 +193,7 @@ class CompilationConfig:
By default, all custom ops are enabled when running without Inductor and
disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
Inductor generates (fused) Triton kernels for disabled custom ops."""
splitting_ops: list[str] = field(default_factory=list)
splitting_ops: Optional[list[str]] = None
"""A list of ops to split the full graph into subgraphs, used in piecewise
compilation."""
@ -187,7 +223,43 @@ class CompilationConfig:
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
# CudaGraph compilation
use_cudagraph: bool = field(default_factory=lambda: envs.VLLM_USE_V1)
cudagraph_mode: Optional[CUDAGraphMode] = None
"""
The mode of the cudagraph.
- NONE, no cudagraph capture.
- PIECEWISE. (v1 default)
- FULL.
- FULL_DECODE_ONLY.
- FULL_AND_PIECEWISE.
PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph
incompatiable ops (i.e. some attention ops) outside the cudagraph
for general flexibility.
This is the default mode.
FULL mode: Capture full cudagraph for all batches. Can be good for small
models or workloads with small prompts; not supported by many backends.
Generally for performance FULL_AND_PIECEWISE is better.
FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only.
Mixed prefill-decode batches are run without cudagraphs. Can be good for
decode instances in a P/D setup where prefill is not as important so we
can save some memory.
FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
piecewise cudagraph for prefill and mixed prefill-decode batches.
This is like the most performant mode for most models.
Currently, the cudagraph mode is only used for the v1 engine.
Note that the cudagraph logic is generally orthogonal to the
compilation logic. While piecewise cudagraphs require piecewise
compilation (level=PIECEWISE and non-empty splitting_ops), full
cudagraphs are supported with and without compilation.
Warning: This flag is new and subject to change in addition
more modes may be added.
"""
use_cudagraph: bool = True
"""Whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
@ -197,8 +269,9 @@ class CompilationConfig:
CompilationLevel.PIECEWISE (aka -O3).
Note that this is orthogonal to the cudagraph capture logic
outside of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future."""
Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
"""
cudagraph_num_of_warmups: int = 0
"""Number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
@ -213,12 +286,17 @@ class CompilationConfig:
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False."""
full_cuda_graph: bool = False
internally managed buffer. Default is False.
Note that this flag is only effective when cudagraph_mode is PIECEWISE.
"""
full_cuda_graph: Optional[bool] = False
"""whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models."""
performance benefits for smaller models.
Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
"""
pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details"""
@ -253,6 +331,13 @@ class CompilationConfig:
Map from layer name to layer objects that need to be accessed outside
model code, e.g., Attention, FusedMOE when dp_size>1."""
# Attention ops; used for piecewise cudagraphs
_attention_ops: ClassVar[list[str]] = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.mamba_mixer2",
]
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@ -297,13 +382,26 @@ class CompilationConfig:
if pass_config_exclude:
exclude["pass_config"] = pass_config_exclude
return TypeAdapter(CompilationConfig).dump_json(
self,
exclude=exclude, # type: ignore[arg-type]
exclude_unset=True).decode()
# The cast to string is necessary because Pydantic is mocked in docs
# builds and sphinx-argparse doesn't know the return type of decode()
return str(
TypeAdapter(CompilationConfig).dump_json(
self,
exclude=exclude, # type: ignore[arg-type]
exclude_unset=True).decode())
__str__ = __repr__
@field_validator("cudagraph_mode", mode="before")
@classmethod
def validate_cudagraph_mode_before(cls, value: Any) -> Any:
"""
enable parse the `cudagraph_mode` enum type from string
"""
if isinstance(value, str):
return CUDAGraphMode[value.upper()]
return value
def __post_init__(self) -> None:
count_none = self.custom_ops.count("none")
count_all = self.custom_ops.count("all")
@ -341,7 +439,26 @@ class CompilationConfig:
if isinstance(self.pass_config, dict):
self.pass_config = PassConfig(**self.pass_config)
def init_backend(self, vllm_config: VllmConfig) -> Union[str, Callable]:
# migrate the deprecated flags
if not self.use_cudagraph:
logger.warning("use_cudagraph is deprecated, use "
"cudagraph_mode=NONE instead.")
if self.cudagraph_mode is not None:
raise ValueError(
"use_cudagraph and cudagraph_mode are mutually"
" exclusive, prefer cudagraph_mode since "
"use_cudagraph is deprecated.")
self.cudagraph_mode = CUDAGraphMode.NONE
if self.full_cuda_graph:
logger.warning("full_cuda_graph is deprecated, use "
"cudagraph_mode=FULL instead.")
if self.cudagraph_mode is not None:
raise ValueError("full_cuda_graph and cudagraph_mode are "
"mutually exclusive, prefer cudagraph_mode "
"since full_cuda_graph is deprecated.")
self.cudagraph_mode = CUDAGraphMode.FULL
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
@ -414,15 +531,34 @@ class CompilationConfig:
self.max_capture_size] = self.max_capture_size
def set_splitting_ops_for_v1(self):
# NOTE: this function needs to be called
if self.splitting_ops and self.full_cuda_graph:
raise ValueError("full_cuda_graph cannot be used together with "
"splitting_ops, as Full CUDA graph will override "
f"the splitting_ops: {self.splitting_ops}")
# NOTE: this function needs to be called only when level is
# CompilationLevel.PIECEWISE
assert self.level == CompilationLevel.PIECEWISE, (
"set_splitting_ops_for_v1 should only be called when "
"level is CompilationLevel.PIECEWISE")
if not self.splitting_ops:
self.splitting_ops = [] if self.full_cuda_graph else [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.mamba_mixer2",
]
if self.splitting_ops is None:
# NOTE: When using full cudagraph, instead of setting an empty
# list and capture the full cudagraph inside the flattened fx
# graph, we keep the piecewise fx graph structure but capture the
# full cudagraph outside the fx graph. This reduces some cpu
# overhead when the runtime batch_size is not cudagraph captured.
# see https://github.com/vllm-project/vllm/pull/20059 for details.
self.splitting_ops = self._attention_ops
elif len(self.splitting_ops) == 0:
logger.warning_once("Using piecewise compilation with empty "
"splitting_ops.")
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once(
"When compilation level is piecewise with empty "
"splitting_ops, PIECEWISE cudagraph_mode will be "
"treated as FULL cudagraph_mode. Please ensure you are "
"using attention backends that support cudagraph or set "
"cudagraph_mode to NONE explicitly if encountering "
"any problems.")
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []
def splitting_ops_contain_attention(self) -> bool:
return self.splitting_ops is not None and all(
op in self.splitting_ops for op in self._attention_ops)

View File

@ -325,4 +325,8 @@ class KVConnectorBase_V1(ABC):
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
if cls is KVConnectorBase_V1:
raise TypeError("get_required_kvcache_layout should not be called "
"on the abstract base class")
return None

View File

@ -228,9 +228,10 @@ class MultiConnector(KVConnectorBase_V1):
for ktc in ktcs:
kv_transfer_config = KVTransferConfig(**ktc)
temp_vllm_config.kv_transfer_config = kv_transfer_config
connector_cls = KVConnectorFactory.get_connector_class(
kv_transfer_config)
required_kvcache_layout = (
KVConnectorBase_V1.get_required_kvcache_layout(
temp_vllm_config))
connector_cls.get_required_kvcache_layout(temp_vllm_config))
if required_kvcache_layout is not None:
layouts.add(required_kvcache_layout)

View File

@ -27,12 +27,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LogprobsMode,
LoRAConfig, ModelConfig, ModelDType, ModelImpl,
MultiModalConfig, ObservabilityConfig, ParallelConfig,
PoolerConfig, PrefixCachingHashAlgo, RunnerOption,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
get_field)
LoRAConfig, MambaDType, ModelConfig, ModelDType,
ModelImpl, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
RunnerOption, SchedulerConfig, SchedulerPolicy,
SpeculativeConfig, TaskOption, TokenizerMode,
VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@ -350,6 +350,7 @@ class EngineArgs:
MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = False # DEPRECATED
mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
# LoRA fields
enable_lora: bool = False
enable_lora_bias: bool = LoRAConfig.bias_enabled
@ -421,6 +422,8 @@ class EngineArgs:
override_attention_dtype: str = ModelConfig.override_attention_dtype
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
additional_config: dict[str, Any] = \
get_field(VllmConfig, "additional_config")
@ -693,6 +696,10 @@ class EngineArgs:
**cache_kwargs["calculate_kv_scales"])
cache_group.add_argument("--kv-sharing-fast-prefill",
**cache_kwargs["kv_sharing_fast_prefill"])
cache_group.add_argument("--mamba-cache-dtype",
**cache_kwargs["mamba_cache_dtype"])
cache_group.add_argument("--mamba-ssm-cache-dtype",
**cache_kwargs["mamba_ssm_cache_dtype"])
# Multimodal related configs
multimodal_kwargs = get_kwargs(MultiModalConfig)
@ -716,6 +723,8 @@ class EngineArgs:
multimodal_group.add_argument(
"--interleave-mm-strings",
**multimodal_kwargs["interleave_mm_strings"])
multimodal_group.add_argument("--skip-mm-profiling",
**multimodal_kwargs["skip_mm_profiling"])
# LoRA related configs
lora_kwargs = get_kwargs(LoRAConfig)
@ -918,6 +927,7 @@ class EngineArgs:
limit_mm_per_prompt=self.limit_mm_per_prompt,
interleave_mm_strings=self.interleave_mm_strings,
media_io_kwargs=self.media_io_kwargs,
skip_mm_profiling=self.skip_mm_profiling,
use_async_output_proc=not self.disable_async_output_proc,
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
@ -1101,6 +1111,8 @@ class EngineArgs:
cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales,
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
mamba_cache_dtype=self.mamba_cache_dtype,
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
)
ray_runtime_env = None

View File

@ -126,7 +126,7 @@ async def lifespan(app: FastAPI):
async def _force_log():
while True:
await asyncio.sleep(10.)
await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
await engine_client.do_log_stats()
task = asyncio.create_task(_force_log())

View File

@ -38,6 +38,7 @@ if TYPE_CHECKING:
VLLM_LOGGING_PREFIX: str = ""
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None
VLLM_LOG_STATS_INTERVAL: float = 10.
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
@ -122,6 +123,7 @@ if TYPE_CHECKING:
VLLM_MOE_DP_CHUNK_SIZE: int = 256
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_MXFP4_USE_MARLIN: Optional[bool] = None
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
@ -182,6 +184,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
return int(value)
def maybe_convert_bool(value: Optional[str]) -> Optional[bool]:
if value is None:
return None
return bool(int(value))
def get_vllm_port() -> Optional[int]:
"""Get the port from VLLM_PORT environment variable.
@ -429,6 +437,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: int(os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0"))
if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ else None,
# If set, vllm will log stats at this interval in seconds
# If not set, vllm will log stats every 10 seconds.
"VLLM_LOG_STATS_INTERVAL":
lambda: val if (val := float(os.getenv("VLLM_LOG_STATS_INTERVAL", "10.")))
> 0. else 10.,
# Trace function calls
# If set to 1, vllm will trace function calls
# Useful for debugging
@ -906,6 +920,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MARLIN_USE_ATOMIC_ADD":
lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1",
# Whether to use marlin kernel in mxfp4 quantization method
"VLLM_MXFP4_USE_MARLIN":
lambda: maybe_convert_bool(os.environ.get("VLLM_MXFP4_USE_MARLIN", None)),
# Whether to turn on the outlines cache for V0
# This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users.
@ -1090,6 +1108,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRTLLM_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
# If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
# Otherwise, uses the first available of: flashinfer cutlass GEMM,
# vllm cutlass GEMM, marlin GEMM.
"VLLM_USE_TRTLLM_FP4_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))),
# Controls garbage collection during CUDA graph capture.
# If set to 0 (default), enables GC freezing to speed up capture time.
# If set to 1, allows GC to run during capture.
@ -1197,6 +1221,7 @@ def compute_hash() -> str:
"VLLM_DP_SIZE",
"VLLM_USE_STANDALONE_COMPILE",
"VLLM_FUSED_MOE_CHUNK_SIZE",
"VLLM_USE_TRTLLM_FP4_GEMM",
]
for key in environment_variables_to_hash:
if key in environment_variables:

View File

@ -5,13 +5,13 @@ import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
if TYPE_CHECKING:
@ -26,6 +26,27 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time: defaultdict = defaultdict(list)
class BatchDescriptor(NamedTuple):
"""
Batch descriptor for cudagraph dispatching. We should keep the num of
items as minimal as possible to properly and uniquely describe the padded
batch for cudagraph.
"""
num_tokens: int
uniform_decode: bool = False
"""
False can also be used for an uniform decode batch to dispatch to the
cudagraph supporting non-uniform batches.
"""
@property
def non_uniform(self) -> "BatchDescriptor":
"""
Return a non-uniform version of current batch descriptor.
"""
return BatchDescriptor(self.num_tokens, uniform_decode=False)
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
max_num_tokens: int,
chunk_idx: int) -> list[int]:
@ -152,7 +173,15 @@ class ForwardContext:
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass
dp_metadata: Optional[DPMetadata] = None
skip_cuda_graphs: bool = False
# determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
# by default NONE, no cudagraph is used.
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
batch_descriptor: Optional[BatchDescriptor] = None
def __post_init__(self):
assert self.cudagraph_runtime_mode in [
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
_forward_context: Optional[ForwardContext] = None
@ -168,13 +197,13 @@ def get_forward_context() -> ForwardContext:
@contextmanager
def set_forward_context(
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
skip_cuda_graphs: bool = False,
):
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
@ -198,7 +227,8 @@ def set_forward_context(
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata,
skip_cuda_graphs=skip_cuda_graphs,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
)
try:

View File

@ -18,6 +18,8 @@ from vllm.utils import direct_register_custom_op
def fused_marlin_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
bias1: Optional[torch.Tensor],
bias2: Optional[torch.Tensor],
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
@ -26,6 +28,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
activation: Optional[str] = "silu",
expert_map: Optional[torch.Tensor] = None,
global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None,
@ -88,6 +91,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert num_bits in [4, 8]
assert topk_weights.dtype == torch.float32
M, K = hidden_states.shape
E = w1.shape[0]
@ -138,6 +142,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
hidden_states,
intermediate_cache1,
w1,
bias1,
w1_scale,
global_scale1,
w1_zeros,
@ -161,8 +166,28 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
use_fp32_reduce=True,
is_zp_float=False)
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, 2 * N))
if activation == "silu":
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, 2 * N))
elif activation == "swiglu_oai":
# NOTE: in gpt-oss, the gate_proj and up_proj is interleaved
# - interleaved: gate, up = gate_up[..., ::2], gate_up[..., 1::2]
# - origin: gate, up = gate_up[..., :N], gate_up[..., N:]
@torch.compile(dynamic=True)
def swiglu_oai(gate_up):
alpha = 1.702
limit = 7.0
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
glu = gate * torch.sigmoid(gate * alpha)
return (up + 1) * glu
intermediate_cache2 = swiglu_oai(intermediate_cache1)
else:
raise ValueError(f"Unsupported activation: {activation}. "
"Only silu and swiglu_oai activations are supported.")
if expert_map is not None:
intermediate_cache3.zero_()
@ -171,6 +196,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
intermediate_cache2,
intermediate_cache3,
w2,
bias2,
w2_scale,
global_scale2,
w2_zeros,

View File

@ -1189,10 +1189,10 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
activation_scale: torch.Tensor,
gemm2_weights: torch.Tensor,
gemm2_weights_scale: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: Optional[int],
@ -1206,17 +1206,12 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
num_expert_group = num_expert_group if num_expert_group is not None else 0
topk_group = topk_group if topk_group is not None else 0
quant_hidden_states, input_scale = moe_kernel_quantize_input(
quant_hidden_states, _ = moe_kernel_quantize_input(
hidden_states,
input_scale,
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False)
output1_scales_scalar = gemm1_weights_scale * input_scale * (
1.0 / activation_scale)
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
output2_scales_scalar = activation_scale * gemm2_weights_scale
from vllm.utils.flashinfer import (
flashinfer_trtllm_fp8_per_tensor_scale_moe)
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
@ -1244,24 +1239,24 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
gemm2_weights: torch.Tensor,
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: int,
topk_group: int,
num_expert_group: Optional[int],
topk_group: Optional[int],
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: float = 1.0,
use_routing_scales_on_input: bool = False,
tile_tokens_dim: int = 8,
routing_method_type: int = 0) -> torch.Tensor:
use_routing_scales_on_input: bool,
routing_method_type: int,
routed_scaling_factor: float = 1.0) -> torch.Tensor:
pass

View File

@ -36,7 +36,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
has_triton_kernels, is_torch_equal_or_newer, round_up)
round_up)
from vllm.utils.flashinfer import has_flashinfer
if current_platform.is_cuda_alike():
@ -751,19 +751,11 @@ class FusedMoE(CustomOp):
self.global_num_experts = num_experts + num_redundant_experts
# we padding globally so EP buffer allocation works
if quant_config and quant_config.get_name() == "mxfp4":
if not current_platform.is_device_capability(100):
if not is_torch_equal_or_newer("2.8.0"):
raise RuntimeError(
"Mxfp4 on non-blackwell requires torch >= 2.8.0")
if not has_triton_kernels():
raise NotImplementedError(
"triton_kernels must be installed for "
"mxfp4 on non-blackwell")
if (current_platform.is_rocm()
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
hidden_size = round_up(hidden_size, 256)
if (quant_config and quant_config.get_name() == "mxfp4"
and (current_platform.is_rocm()
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16)):
hidden_size = round_up(hidden_size, 256)
# For smuggling this layer into the fused moe custom op
compilation_config = vllm_config.compilation_config

View File

@ -11,7 +11,7 @@ from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba_attn import (
from vllm.v1.attention.backends.mamba2_attn import (
Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)

View File

@ -1,14 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import NamedTuple, Optional
import torch
from torch import nn
from torch.nn.parameter import Parameter
from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context
@ -19,7 +20,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@ -55,6 +56,8 @@ class MambaMixer(MambaBase, CustomOp):
rms_norm_eps: float = 1e-5,
activation="silu",
is_lora_enabled: bool = False,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = ""):
super().__init__()
self.time_step_rank = time_step_rank
@ -152,15 +155,42 @@ class MambaMixer(MambaBase, CustomOp):
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
self.model_config = model_config
self.cache_config = cache_config
self.prefix = prefix
def _ssm_transform(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if self.is_lora_enabled:
# Lora kernel requires contiguous tensor.
ssm_params = self.x_proj(x.contiguous())[0]
else:
ssm_params = self.x_proj(x)[0]
time_step, B, C = torch.split(
ssm_params,
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1)
if self.use_rms_norm:
assert self.dt_layernorm is not None
assert self.b_layernorm is not None
assert self.c_layernorm is not None
time_step = self.dt_layernorm(time_step.contiguous())
B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous())
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
return discrete_time_step, B, C
def forward(self,
hidden_states: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
if not envs.VLLM_USE_V1:
return CustomOp.forward(self, hidden_states, mamba_cache_params)
else:
return self.forward_cuda(hidden_states, mamba_cache_params)
return self.forward_cuda(
hidden_states,
mamba_cache_params,
)
def forward_native(self,
hidden_states: torch.Tensor,
@ -170,6 +200,27 @@ class MambaMixer(MambaBase, CustomOp):
def forward_cuda(self,
hidden_states: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
"""
Run the Mamba-1 SSM pipeline.
Steps
-----
1. Apply the gated-MLP linear projection to the raw input.
2. Pass the projected sequence through the convolutional mixing layer.
3. Feed the result into the State-Space Model (SSM) blocks.
4. Perform the recurrence y SSM(A, B, C, Δ)(x)
to produce contextual representations.
5. Project the contextualised sequence back
to the output embedding dimension.
Batch handling
--------------
Prefill and decode tokens are processed by dedicated CUDA
kernels for both the convolutional (conv1d) and SSM stages.
In the case of a mixed batch (containing both prefill and
decode tokens), both sets of kernels are executed independently
and their outputs are concatenated before the final output projection.
"""
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
@ -185,126 +236,151 @@ class MambaMixer(MambaBase, CustomOp):
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_state = mamba1_metadata.has_initial_states
context_lens_tensor = mamba1_metadata.context_lens_tensor
has_initial_states = mamba1_metadata.has_initial_states
else:
assert isinstance(attn_metadata, AttentionMetadata)
assert mamba_cache_params is not None
conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state
state_indices_tensor = mamba_cache_params.state_indices_tensor
query_start_loc = attn_metadata.query_start_loc
context_lens_tensor = attn_metadata.context_lens_tensor
has_initial_states = None
if context_lens_tensor is not None:
has_initial_state = context_lens_tensor > 0
has_initial_states = context_lens_tensor > 0
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=-2)
hidden_states_BC, gate = projected_states.chunk(2, dim=-2)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if envs.VLLM_USE_V1 and attn_metadata is None:
# V1 profile run
hidden_states = hidden_states.contiguous()
return self.out_proj(hidden_states.transpose(-2, -1))[0]
hidden_states_BC = hidden_states_BC.contiguous()
return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
if query_start_loc is not None and context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states = causal_conv1d_fn(
hidden_states,
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
num_decode_tokens = attn_metadata.num_decode_tokens
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
has_prefill = num_prefill_tokens > 0
has_decode = num_decode_tokens > 0
prefill_decode_split = split_batch_to_prefill_and_decode(
hidden_states_BC,
gate,
state_indices_tensor,
query_start_loc,
has_initial_states,
num_prefill_tokens,
num_decode_tokens,
num_prefills,
num_decodes,
)
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
gate_p = prefill_decode_split.gate_p
gate_d = prefill_decode_split.gate_d
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
query_start_loc_p = prefill_decode_split.query_start_loc_p
has_initial_states_p = prefill_decode_split.has_initial_states_p
ssm_outputs = []
if has_prefill:
# 2. Convolution sequence transformation
conv_out_p = causal_conv1d_fn(
hidden_states_BC_p,
conv_weights,
bias=self.conv1d.bias,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
has_initial_state=has_initial_state,
cache_indices=state_indices_tensor,
query_start_loc=query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p)
# 3. State Space Model sequence transformations.
discrete_time_step_p, B_p, C_p = self._ssm_transform(
conv_out_p.transpose(-2, -1))
time_proj_bias = self._time_proj_bias()
# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
scan_out_p = selective_scan_fn(
conv_out_p,
ssm_state,
discrete_time_step_p,
self.A,
B_p.transpose(-2, -1),
C_p.transpose(-2, -1),
self.D.float(),
gate_p,
time_proj_bias,
delta_softplus=True,
cache_indices=state_indices_tensor_p,
has_initial_state=has_initial_states_p,
query_start_loc=query_start_loc_p)
ssm_outputs.append(scan_out_p)
if has_decode:
# 2. Convolution sequence transformation
conv_out_d = causal_conv1d_update(
hidden_states_BC_d.transpose(0, 1),
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)
conv_state_indices=state_indices_tensor_d).transpose(0, 1)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
# 3. State Space Model sequence transformation.
discrete_time_step_d, B_d, C_d = self._ssm_transform(
conv_out_d.transpose(-2, -1))
time_proj_bias = self._time_proj_bias()
if self.is_lora_enabled:
# lora kernel requires contiguous tensor
ssm_parameters = self.x_proj(
hidden_states.transpose(-2, -1).contiguous())[0]
else:
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
time_step, B, C = torch.split(
ssm_parameters,
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1,
)
if self.use_rms_norm:
assert self.dt_layernorm is not None
assert self.b_layernorm is not None
assert self.c_layernorm is not None
time_step = self.dt_layernorm(time_step.contiguous())
B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous())
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None)
if query_start_loc is not None and context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
C.transpose(-2, -1),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
cache_indices=state_indices_tensor,
has_initial_state=has_initial_state,
query_start_loc=query_start_loc)
else:
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
scan_outputs_d = torch.empty_like(
hidden_states_BC_d.transpose(0, 1))
selective_state_update(ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
conv_out_d.transpose(0, 1),
discrete_time_step_d.transpose(0, 1),
self.A,
B,
C,
B_d,
C_d,
self.D,
gate.transpose(0, 1),
gate_d.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor,
out=scan_outputs)
scan_outputs = scan_outputs.transpose(0, 1)
state_batch_indices=state_indices_tensor_d,
out=scan_outputs_d)
scan_outputs_d = scan_outputs_d.transpose(0, 1)
# 4. Final linear projection
if self.is_lora_enabled:
# lora kernel requires contiguous tensor
contextualized_states = self.out_proj(
scan_outputs.transpose(-2, -1).contiguous())[0]
if envs.VLLM_USE_V1:
ssm_outputs.insert(0, scan_outputs_d)
else:
ssm_outputs.append(scan_outputs_d)
scan_outputs_combined = ssm_outputs[0] if len(
ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
# 5. Final output projection
if self.is_lora_enabled: # Lora kernel requires contiguous tensor.
scan_outputs_combined = scan_outputs_combined.transpose(
-2, -1).contiguous()
out = self.out_proj(scan_outputs_combined)[0]
else:
contextualized_states = self.out_proj(
scan_outputs.transpose(-2, -1))[0]
return contextualized_states
out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
return out
def get_state_dtype(self) -> tuple[torch.dtype]:
assert self.model_config is not None
assert self.cache_config is not None
return MambaStateDtypeCalculator.mamba1_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
self.cache_config.mamba_ssm_cache_dtype,
)
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.mamba1_state_shape(
@ -317,3 +393,69 @@ class MambaMixer(MambaBase, CustomOp):
@property
def mamba_type(self) -> str:
return "mamba1"
def _time_proj_bias(self) -> Optional[torch.Tensor]:
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
return self.dt_proj.bias.float()
return None
class PrefillDecodeSplit(NamedTuple):
hidden_states_BC_p: torch.Tensor
hidden_states_BC_d: torch.Tensor
gate_p: torch.Tensor
gate_d: torch.Tensor
state_indices_tensor_p: torch.Tensor
state_indices_tensor_d: torch.Tensor
query_start_loc_p: Optional[torch.Tensor]
has_initial_states_p: Optional[torch.Tensor]
def split_batch_to_prefill_and_decode(
hidden_states_BC: torch.Tensor,
gate: torch.Tensor,
state_indices_tensor: torch.Tensor,
query_start_loc: torch.Tensor,
has_initial_states: Optional[torch.Tensor],
num_prefill_tokens: int,
num_decode_tokens: int,
num_prefills: int,
num_decodes: int,
) -> PrefillDecodeSplit:
if envs.VLLM_USE_V1:
# In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d, hidden_states_BC_p = torch.split(
hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
gate_d, gate_p = torch.split(gate,
[num_decode_tokens, num_prefill_tokens],
dim=-1)
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor, [num_decodes, num_prefills], dim=0)
query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
num_decodes if num_prefills > 0 else None)
has_initial_states_p = has_initial_states[-num_prefills:] if (
has_initial_states is not None and num_prefills > 0) else None
else:
# In v0, prefill tokens come first, then decode tokens.
hidden_states_BC_p, hidden_states_BC_d = torch.split(
hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)
gate_p, gate_d = torch.split(gate,
[num_prefill_tokens, num_decode_tokens],
dim=-1)
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor, [num_prefills, num_decodes], dim=0)
query_start_loc_p = (query_start_loc[:num_prefills +
1] if num_prefills > 0 else None)
has_initial_states_p = has_initial_states[:num_prefills] if (
has_initial_states is not None and num_prefills > 0) else None
return PrefillDecodeSplit(
hidden_states_BC_p=hidden_states_BC_p,
hidden_states_BC_d=hidden_states_BC_d,
gate_p=gate_p,
gate_d=gate_d,
state_indices_tensor_p=state_indices_tensor_p,
state_indices_tensor_d=state_indices_tensor_d,
query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p,
)

View File

@ -8,7 +8,7 @@ from torch import nn
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import get_current_vllm_config
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
@ -21,7 +21,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
@ -36,7 +36,7 @@ from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
# Added by the IBM Team, 2024
@ -218,23 +218,23 @@ class MambaMixer2(MambaBase, CustomOp):
**selective** state spaces)
"""
def __init__(
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
# For TP, the sharding plan is as follows:
@ -417,6 +417,8 @@ class MambaMixer2(MambaBase, CustomOp):
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
self.model_config = model_config
self.cache_config = cache_config
self.prefix = prefix
def forward_native(
@ -670,7 +672,7 @@ class MambaMixer2(MambaBase, CustomOp):
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
self.head_dim),
)
state_dtype=ssm_state.dtype)
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
@ -732,6 +734,15 @@ class MambaMixer2(MambaBase, CustomOp):
# 5. Final linear projection
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
assert self.model_config is not None
assert self.cache_config is not None
return MambaStateDtypeCalculator.mamba2_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
self.cache_config.mamba_ssm_cache_dtype,
)
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=self.intermediate_size,

View File

@ -1,6 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Union
import torch
from vllm.config import MambaDType, ModelDType
from vllm.distributed import divide
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype
class MambaStateDtypeCalculator:
@classmethod
def linear_attention_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
# TODO (tdoublep) requires testing
if mamba_cache_dtype == "float32":
raise ValueError("fp32 state for minimax is not yet supported")
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype, )
@classmethod
def mamba1_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
# TODO (tdoublep) requires kernel changes
if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32":
raise ValueError("fp32 state for mamba1 is not yet supported")
else:
return MambaStateDtypeCalculator.mamba2_state_dtype(
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype)
@classmethod
def mamba2_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
model_dtype)
if mamba_ssm_cache_dtype == "auto":
temporal_state_dtype = conv_state_dtype
else:
temporal_state_dtype = (
STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype])
return (conv_state_dtype, temporal_state_dtype)
class MambaStateShapeCalculator:

View File

@ -41,6 +41,7 @@ def _mamba_chunk_scan_combined_fwd(x,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
state_dtype=None,
out=None):
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
batch, seqlen, nheads, headdim = x.shape
@ -118,7 +119,7 @@ def _mamba_chunk_scan_combined_fwd(x,
if initial_states is not None else None,
seq_idx=seq_idx,
chunk_size=chunk_size,
out_dtype=C.dtype,
out_dtype=state_dtype if state_dtype is not None else C.dtype,
is_cont_batched=cu_seqlens is not None)
states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
for t in [states, final_states])
@ -189,7 +190,8 @@ def mamba_chunk_scan_combined(x,
dt_limit=(0.0, float("inf")),
out=None,
return_final_states=False,
return_varlen_states=False):
return_varlen_states=False,
state_dtype=None):
"""
Argument:
x: (batch, seqlen, nheads, headdim)
@ -206,6 +208,7 @@ def mamba_chunk_scan_combined(x,
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
dt_softplus: Whether to apply softplus to dt
out: Preallocated output tensor
state_dtype: The data type of the ssm state
"""
if not return_varlen_states:
@ -229,7 +232,8 @@ def mamba_chunk_scan_combined(x,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit,
out=out)
out=out,
state_dtype=state_dtype)
if not return_varlen_states:
if not return_final_states:
return

Some files were not shown because too many files have changed in this diff Show More