mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-11 02:35:42 +08:00
Merge branch 'main' into wye-refactor-quant-folder
This commit is contained in:
commit
7e2fb3c507
@ -31,16 +31,6 @@
|
|||||||
steps:
|
steps:
|
||||||
##### fast check tests #####
|
##### fast check tests #####
|
||||||
|
|
||||||
- label: Documentation Build # 2min
|
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
working_dir: "/vllm-workspace/test_docs"
|
|
||||||
fast_check: true
|
|
||||||
no_gpu: True
|
|
||||||
commands:
|
|
||||||
- pip install -r ../requirements/docs.txt
|
|
||||||
# TODO: add `--strict` once warnings in docstrings are fixed
|
|
||||||
- mkdocs build
|
|
||||||
|
|
||||||
- label: Pytorch Nightly Dependency Override Check # 2min
|
- label: Pytorch Nightly Dependency Override Check # 2min
|
||||||
# if this test fails, it means the nightly torch version is not compatible with some
|
# if this test fails, it means the nightly torch version is not compatible with some
|
||||||
# of the dependencies. Please check the error message and add the package to whitelist
|
# of the dependencies. Please check the error message and add the package to whitelist
|
||||||
@ -669,6 +659,7 @@ steps:
|
|||||||
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
|
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
|
||||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
|
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
|
||||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
|
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
|
||||||
|
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
|
||||||
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
|
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
|
||||||
# Fusion
|
# Fusion
|
||||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -207,3 +207,6 @@ shellcheck*/
|
|||||||
|
|
||||||
# Ignore moe/marlin_moe gen code
|
# Ignore moe/marlin_moe gen code
|
||||||
csrc/moe/marlin_moe_wna16/kernel_*
|
csrc/moe/marlin_moe_wna16/kernel_*
|
||||||
|
|
||||||
|
# Ignore ep_kernels_workspace folder
|
||||||
|
ep_kernels_workspace/
|
||||||
@ -249,7 +249,6 @@ set(VLLM_EXT_SRC
|
|||||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||||
"csrc/quantization/activation_kernels.cu"
|
"csrc/quantization/activation_kernels.cu"
|
||||||
"csrc/cuda_utils_kernels.cu"
|
"csrc/cuda_utils_kernels.cu"
|
||||||
"csrc/prepare_inputs/advance_step.cu"
|
|
||||||
"csrc/custom_all_reduce.cu"
|
"csrc/custom_all_reduce.cu"
|
||||||
"csrc/torch_bindings.cpp")
|
"csrc/torch_bindings.cpp")
|
||||||
|
|
||||||
@ -352,6 +351,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
set_gencode_flags_for_srcs(
|
set_gencode_flags_for_srcs(
|
||||||
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
|
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
|
||||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||||
|
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
|
||||||
|
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||||
|
endif()
|
||||||
|
|
||||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||||
|
|
||||||
@ -365,7 +368,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
set_gencode_flags_for_srcs(
|
set_gencode_flags_for_srcs(
|
||||||
SRCS "${MARLIN_SRCS}"
|
SRCS "${MARLIN_SRCS}"
|
||||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||||
|
set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||||
|
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||||
|
endif()
|
||||||
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
|
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
|
||||||
|
|
||||||
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
|
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
|
||||||
else()
|
else()
|
||||||
message(STATUS "Not building Marlin kernels as no compatible archs found"
|
message(STATUS "Not building Marlin kernels as no compatible archs found"
|
||||||
@ -855,6 +863,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
set_gencode_flags_for_srcs(
|
set_gencode_flags_for_srcs(
|
||||||
SRCS "${MOE_WNAA16_MARLIN_SRC}"
|
SRCS "${MOE_WNAA16_MARLIN_SRC}"
|
||||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||||
|
set_source_files_properties(${MOE_WNAA16_MARLIN_SRC}
|
||||||
|
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||||
|
endif()
|
||||||
|
|
||||||
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
|
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
|
||||||
|
|
||||||
|
|||||||
@ -1,63 +1,199 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from quart import Quart, make_response, request
|
from quart import Quart, Response, make_response, request
|
||||||
|
from rate_limiter import RateLimiter
|
||||||
|
from request_queue import RequestQueue
|
||||||
|
|
||||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
app = Quart(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def forward_request(url, data):
|
def parse_args():
|
||||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
"""parse command line arguments"""
|
||||||
|
parser = argparse.ArgumentParser(description="vLLM P/D disaggregation proxy server")
|
||||||
|
|
||||||
|
# Add args
|
||||||
|
parser.add_argument(
|
||||||
|
"--timeout",
|
||||||
|
type=float,
|
||||||
|
default=300,
|
||||||
|
help="Timeout for backend service requests in seconds (default: 300)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-concurrent",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Maximum concurrent requests to backend services (default: 100)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--queue-size",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="Maximum number of requests in the queue (default: 500)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rate-limit",
|
||||||
|
type=int,
|
||||||
|
default=40,
|
||||||
|
help="Maximum requests per second (default: 40)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port",
|
||||||
|
type=int,
|
||||||
|
default=8000,
|
||||||
|
help="Port to run the server on (default: 8000)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefill-url",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:8100/v1/completions",
|
||||||
|
help="Prefill service endpoint URL",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode-url",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:8200/v1/completions",
|
||||||
|
help="Decode service endpoint URL",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""parse command line arguments"""
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# Initialize configuration using command line parameters
|
||||||
|
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
|
||||||
|
MAX_CONCURRENT_REQUESTS = args.max_concurrent
|
||||||
|
REQUEST_QUEUE_SIZE = args.queue_size
|
||||||
|
RATE_LIMIT = args.rate_limit
|
||||||
|
PREFILL_SERVICE_URL = args.prefill_url
|
||||||
|
DECODE_SERVICE_URL = args.decode_url
|
||||||
|
PORT = args.port
|
||||||
|
|
||||||
|
app = Quart(__name__)
|
||||||
|
|
||||||
|
# Initialize the rate limiter and request queue
|
||||||
|
rate_limiter = RateLimiter(RATE_LIMIT)
|
||||||
|
request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE)
|
||||||
|
|
||||||
|
# Attach the configuration object to the application instance
|
||||||
|
app.config.update(
|
||||||
|
{
|
||||||
|
"AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
|
||||||
|
"rate_limiter": rate_limiter,
|
||||||
|
"request_queue": request_queue,
|
||||||
|
"PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
|
||||||
|
"DECODE_SERVICE_URL": DECODE_SERVICE_URL,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start queue processing on app startup
|
||||||
|
@app.before_serving
|
||||||
|
async def startup():
|
||||||
|
"""Start request processing task when app starts serving"""
|
||||||
|
asyncio.create_task(request_queue.process())
|
||||||
|
|
||||||
|
async def forward_request(url, data):
|
||||||
|
"""Forward request to backend service with rate limiting and error handling"""
|
||||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||||
async with session.post(url=url, json=data, headers=headers) as response:
|
|
||||||
if response.status == 200:
|
|
||||||
# if response.headers.get('Transfer-Encoding') == 'chunked':
|
|
||||||
if True:
|
|
||||||
async for chunk_bytes in response.content.iter_chunked(1024):
|
|
||||||
yield chunk_bytes
|
|
||||||
else:
|
|
||||||
content = await response.read()
|
|
||||||
yield content
|
|
||||||
|
|
||||||
|
# Use rate limiter as context manager
|
||||||
@app.route("/v1/completions", methods=["POST"])
|
async with (
|
||||||
async def handle_request():
|
rate_limiter,
|
||||||
try:
|
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
|
||||||
original_request_data = await request.get_json()
|
|
||||||
|
|
||||||
prefill_request = original_request_data.copy()
|
|
||||||
# change max_tokens = 1 to let it only do prefill
|
|
||||||
prefill_request["max_tokens"] = 1
|
|
||||||
|
|
||||||
# finish prefill
|
|
||||||
async for _ in forward_request(
|
|
||||||
"http://localhost:8100/v1/completions", prefill_request
|
|
||||||
):
|
):
|
||||||
continue
|
try:
|
||||||
|
async with session.post(
|
||||||
|
url=url, json=data, headers=headers
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
# Stream response chunks
|
||||||
|
async for chunk_bytes in response.content.iter_chunked(1024):
|
||||||
|
yield chunk_bytes
|
||||||
|
else:
|
||||||
|
# Handle backend service errors
|
||||||
|
error_text = await response.text()
|
||||||
|
logger.error(
|
||||||
|
"Backend service error: %s - %s",
|
||||||
|
response.status,
|
||||||
|
error_text,
|
||||||
|
)
|
||||||
|
yield b'{"error": "Backend service error"}'
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
# Handle connection errors
|
||||||
|
logger.error("Connection error to %s: %s", url, str(e))
|
||||||
|
yield b'{"error": "Service unavailable"}'
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Handle timeout errors
|
||||||
|
logger.error("Timeout connecting to %s", url)
|
||||||
|
yield b'{"error": "Service timeout"}'
|
||||||
|
|
||||||
# return decode
|
async def process_request():
|
||||||
generator = forward_request(
|
"""Process a single request through prefill and decode stages"""
|
||||||
"http://localhost:8200/v1/completions", original_request_data
|
try:
|
||||||
)
|
original_request_data = await request.get_json()
|
||||||
response = await make_response(generator)
|
|
||||||
response.timeout = None
|
|
||||||
|
|
||||||
return response
|
# Create prefill request (max_tokens=1)
|
||||||
|
prefill_request = original_request_data.copy()
|
||||||
|
prefill_request["max_tokens"] = 1
|
||||||
|
|
||||||
except Exception as e:
|
# Execute prefill stage
|
||||||
import sys
|
async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request):
|
||||||
import traceback
|
continue
|
||||||
|
|
||||||
exc_info = sys.exc_info()
|
# Execute decode stage and stream response
|
||||||
print("Error occurred in disagg prefill proxy server")
|
generator = forward_request(DECODE_SERVICE_URL, original_request_data)
|
||||||
print(e)
|
response = await make_response(generator)
|
||||||
print("".join(traceback.format_exception(*exc_info)))
|
response.timeout = None # Disable timeout for streaming response
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error processing request")
|
||||||
|
return Response(
|
||||||
|
response=b'{"error": "Internal server error"}',
|
||||||
|
status=500,
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.route("/v1/completions", methods=["POST"])
|
||||||
|
async def handle_request():
|
||||||
|
"""Handle incoming API requests with concurrency and rate limiting"""
|
||||||
|
# Create task for request processing
|
||||||
|
task = asyncio.create_task(process_request())
|
||||||
|
|
||||||
|
# Enqueue request or reject if queue is full
|
||||||
|
if not await request_queue.enqueue(task):
|
||||||
|
return Response(
|
||||||
|
response=b'{"error": "Server busy, try again later"}',
|
||||||
|
status=503,
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Return the response from the processing task
|
||||||
|
return await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Handle task cancellation (timeout or queue full)
|
||||||
|
logger.warning("Request cancelled due to timeout or queue full")
|
||||||
|
return Response(
|
||||||
|
response=b'{"error": "Request cancelled"}',
|
||||||
|
status=503,
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start the Quart server with host can be set to 0.0.0.0
|
||||||
|
app.run(port=PORT)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
app.run(port=8000)
|
main()
|
||||||
|
|||||||
45
benchmarks/disagg_benchmarks/rate_limiter.py
Normal file
45
benchmarks/disagg_benchmarks/rate_limiter.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""Token bucket rate limiter implementation"""
|
||||||
|
|
||||||
|
def __init__(self, rate_limit):
|
||||||
|
self.rate_limit = rate_limit # Requests per second
|
||||||
|
self.num_available_tokens = rate_limit # Available tokens
|
||||||
|
self.last_refill = time.monotonic() # Last token refill time
|
||||||
|
self.lock = asyncio.Lock() # Synchronization lock
|
||||||
|
|
||||||
|
async def acquire(self):
|
||||||
|
"""Acquire a token from the rate limiter"""
|
||||||
|
while True:
|
||||||
|
async with self.lock:
|
||||||
|
current_time = time.monotonic()
|
||||||
|
elapsed = current_time - self.last_refill
|
||||||
|
|
||||||
|
# Refill num_available_tokens if more than 1 second has passed
|
||||||
|
if elapsed > 1.0:
|
||||||
|
self.num_available_tokens = self.rate_limit
|
||||||
|
self.last_refill = current_time
|
||||||
|
|
||||||
|
# Check if num_available_tokens are available
|
||||||
|
if self.num_available_tokens > 0:
|
||||||
|
self.num_available_tokens -= 1
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Calculate wait time if no num_available_tokens available
|
||||||
|
wait_time = 1.0 - elapsed
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
"""Enter async context manager - acquire token"""
|
||||||
|
await self.acquire()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
|
"""Exit async context manager - no cleanup needed"""
|
||||||
|
pass
|
||||||
39
benchmarks/disagg_benchmarks/request_queue.py
Normal file
39
benchmarks/disagg_benchmarks/request_queue.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
class RequestQueue:
|
||||||
|
"""Request queue manager with concurrency control"""
|
||||||
|
|
||||||
|
def __init__(self, max_concurrent, max_queue_size):
|
||||||
|
# Maximum concurrent requests
|
||||||
|
self.max_concurrent = max_concurrent
|
||||||
|
self.max_queue_size = max_queue_size # Maximum queue size
|
||||||
|
# Concurrency control
|
||||||
|
self.semaphore = asyncio.Semaphore(max_concurrent)
|
||||||
|
self.queue = deque() # Request queue
|
||||||
|
self.queue_size = 0 # Current queue size
|
||||||
|
self.lock = asyncio.Lock() # Sync queue Lock
|
||||||
|
|
||||||
|
async def enqueue(self, task):
|
||||||
|
"""Add a request task to the queue"""
|
||||||
|
async with self.lock:
|
||||||
|
if self.queue_size >= self.max_queue_size:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.queue.append(task)
|
||||||
|
self.queue_size += 1
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def process(self):
|
||||||
|
"""Process queued requests using semaphore for concurrency control"""
|
||||||
|
while True:
|
||||||
|
if self.queue:
|
||||||
|
async with self.semaphore, self.lock:
|
||||||
|
task = self.queue.popleft()
|
||||||
|
self.queue_size -= 1
|
||||||
|
await task
|
||||||
|
await asyncio.sleep(0.01) # Yield control to event loop
|
||||||
@ -236,6 +236,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
|||||||
a=bt.a,
|
a=bt.a,
|
||||||
c=None,
|
c=None,
|
||||||
b_q_weight=w_q,
|
b_q_weight=w_q,
|
||||||
|
b_bias=None,
|
||||||
b_scales=w_s,
|
b_scales=w_s,
|
||||||
global_scale=None,
|
global_scale=None,
|
||||||
b_zeros=w_zp,
|
b_zeros=w_zp,
|
||||||
|
|||||||
@ -321,6 +321,8 @@ static inline constexpr auto kFE3M2f =
|
|||||||
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
|
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
|
||||||
static inline constexpr auto kFE4M3fn =
|
static inline constexpr auto kFE4M3fn =
|
||||||
ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
|
ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
|
||||||
|
static inline constexpr auto kFE8M0fnu =
|
||||||
|
ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
|
||||||
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
|
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
|
||||||
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
|
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
|
||||||
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
|
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
|
||||||
|
|||||||
@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME {
|
|||||||
TEMPLATE = ("template __global__ void Marlin<"
|
TEMPLATE = ("template __global__ void Marlin<"
|
||||||
"{{scalar_t}}, "
|
"{{scalar_t}}, "
|
||||||
"{{w_type_id}}, "
|
"{{w_type_id}}, "
|
||||||
|
"{{s_type_id}}, "
|
||||||
"{{threads}}, "
|
"{{threads}}, "
|
||||||
"{{thread_m_blocks}}, "
|
"{{thread_m_blocks}}, "
|
||||||
"{{thread_n_blocks}}, "
|
"{{thread_n_blocks}}, "
|
||||||
@ -77,6 +78,7 @@ def generate_new_kernels():
|
|||||||
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
||||||
continue
|
continue
|
||||||
# nvfp4 only supports group_size == 16
|
# nvfp4 only supports group_size == 16
|
||||||
|
# mxfp4 only supports group_size == 32
|
||||||
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
|
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
|
||||||
continue
|
continue
|
||||||
# other quantization methods don't support group_size = 16
|
# other quantization methods don't support group_size = 16
|
||||||
@ -89,9 +91,22 @@ def generate_new_kernels():
|
|||||||
|
|
||||||
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
||||||
|
|
||||||
|
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
|
||||||
|
s_type = "vllm::kFE4M3fn"
|
||||||
|
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
|
||||||
|
s_type = "vllm::kFE8M0fnu"
|
||||||
|
if dtype == "fp16":
|
||||||
|
# we cannot safely dequantize e8m0 to fp16, so skip this
|
||||||
|
continue
|
||||||
|
elif dtype == "fp16":
|
||||||
|
s_type = "vllm::kFloat16"
|
||||||
|
elif dtype == "bf16":
|
||||||
|
s_type = "vllm::kBFloat16"
|
||||||
|
|
||||||
template_str = jinja2.Template(TEMPLATE).render(
|
template_str = jinja2.Template(TEMPLATE).render(
|
||||||
scalar_t=c_dtype,
|
scalar_t=c_dtype,
|
||||||
w_type_id=scalar_type + ".id()",
|
w_type_id=scalar_type + ".id()",
|
||||||
|
s_type_id=s_type + ".id()",
|
||||||
threads=threads,
|
threads=threads,
|
||||||
thread_m_blocks=max(m_blocks, 1),
|
thread_m_blocks=max(m_blocks, 1),
|
||||||
thread_n_blocks=n_blocks,
|
thread_n_blocks=n_blocks,
|
||||||
|
|||||||
@ -7,23 +7,25 @@
|
|||||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||||
#include "core/scalar_type.hpp"
|
#include "core/scalar_type.hpp"
|
||||||
|
|
||||||
#define MARLIN_KERNEL_PARAMS \
|
#define MARLIN_KERNEL_PARAMS \
|
||||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||||
const int4 *__restrict__ scales_ptr, \
|
const int4 *__restrict__ b_bias_ptr, \
|
||||||
const uint16_t *__restrict__ scale2_ptr, \
|
const int4 *__restrict__ scales_ptr, \
|
||||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
const uint16_t *__restrict__ scale2_ptr, \
|
||||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||||
const int32_t *__restrict__ expert_ids_ptr, \
|
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
const int32_t *__restrict__ expert_ids_ptr, \
|
||||||
const float *__restrict__ topk_weights_ptr, int top_k, \
|
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
||||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||||
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
|
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
||||||
|
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
|
||||||
bool use_fp32_reduce, int max_shared_mem
|
bool use_fp32_reduce, int max_shared_mem
|
||||||
|
|
||||||
namespace MARLIN_NAMESPACE_NAME {
|
namespace MARLIN_NAMESPACE_NAME {
|
||||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||||
|
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
|
||||||
const int threads, // number of threads in a threadblock
|
const int threads, // number of threads in a threadblock
|
||||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||||
// dimension (batchsize) of the
|
// dimension (batchsize) of the
|
||||||
|
|||||||
@ -280,6 +280,7 @@ __device__ inline void wait_negative_and_add(int* lock) {
|
|||||||
|
|
||||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||||
|
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
|
||||||
const int threads, // number of threads in a threadblock
|
const int threads, // number of threads in a threadblock
|
||||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||||
// dimension (batchsize) of the
|
// dimension (batchsize) of the
|
||||||
@ -299,6 +300,7 @@ __global__ void Marlin(
|
|||||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||||
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
||||||
|
const int4* __restrict__ b_bias_ptr,
|
||||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||||
// (k/groupsize)xn
|
// (k/groupsize)xn
|
||||||
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
|
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
|
||||||
@ -318,8 +320,9 @@ __global__ void Marlin(
|
|||||||
int prob_n, // output dimension n
|
int prob_n, // output dimension n
|
||||||
int prob_k, // reduction dimension k
|
int prob_k, // reduction dimension k
|
||||||
int* locks, // extra global storage for barrier synchronization
|
int* locks, // extra global storage for barrier synchronization
|
||||||
bool use_atomic_add, // whether to use atomic add to reduce
|
bool has_bias,
|
||||||
bool use_fp32_reduce, // whether to use fp32 global reduce
|
bool use_atomic_add, // whether to use atomic add to reduce
|
||||||
|
bool use_fp32_reduce, // whether to use fp32 global reduce
|
||||||
int max_shared_mem) {
|
int max_shared_mem) {
|
||||||
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
|
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
|
||||||
// same size, which might involve multiple column "slices" (of width 16 *
|
// same size, which might involve multiple column "slices" (of width 16 *
|
||||||
@ -342,12 +345,23 @@ __global__ void Marlin(
|
|||||||
|
|
||||||
extern __shared__ int4 sh[];
|
extern __shared__ int4 sh[];
|
||||||
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
||||||
|
static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
|
||||||
|
if constexpr (w_type == vllm::kFE2M1f) {
|
||||||
|
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
|
||||||
|
s_type == vllm::kFE8M0fnu && group_blocks == 2);
|
||||||
|
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||||
|
static_assert(s_type == vllm::kBFloat16);
|
||||||
|
} else if constexpr (std::is_same<scalar_t, half>::value) {
|
||||||
|
static_assert(s_type == vllm::kFloat16);
|
||||||
|
}
|
||||||
|
|
||||||
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
|
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
|
||||||
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
|
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
|
||||||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
|
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
|
||||||
// see comments of dequant.h for more details
|
// see comments of dequant.h for more details
|
||||||
constexpr bool dequant_skip_flop =
|
constexpr bool dequant_skip_flop =
|
||||||
!is_int_type ||
|
w_type == vllm::kFE4M3fn ||
|
||||||
|
w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
|
||||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
|
has_zp && !is_zp_float && !(w_type == vllm::kU8);
|
||||||
|
|
||||||
@ -365,6 +379,7 @@ __global__ void Marlin(
|
|||||||
const int zp_expert_stride =
|
const int zp_expert_stride =
|
||||||
is_zp_float ? prob_n * prob_k / group_size / 8
|
is_zp_float ? prob_n * prob_k / group_size / 8
|
||||||
: prob_n * prob_k / group_size / (pack_factor * 4);
|
: prob_n * prob_k / group_size / (pack_factor * 4);
|
||||||
|
const int b_bias_expert_stride = prob_n / 8;
|
||||||
|
|
||||||
// parallel: num valid moe blocks
|
// parallel: num valid moe blocks
|
||||||
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
|
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
|
||||||
@ -475,7 +490,7 @@ __global__ void Marlin(
|
|||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
int idx = tid4 * 4 + i;
|
int idx = tid4 * 4 + i;
|
||||||
idx = idx < block_num_valid_tokens ? idx : 0;
|
idx = idx < block_num_valid_tokens ? idx : 0;
|
||||||
if constexpr (w_type == vllm::kFE2M1f) {
|
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||||
sh_block_topk_weights[idx] = __hmul2(
|
sh_block_topk_weights[idx] = __hmul2(
|
||||||
global_scale, Dtype::num2num2(Dtype::float2num(
|
global_scale, Dtype::num2num2(Dtype::float2num(
|
||||||
topk_weights_ptr[sh_block_sorted_ids[idx]])));
|
topk_weights_ptr[sh_block_sorted_ids[idx]])));
|
||||||
@ -513,7 +528,7 @@ __global__ void Marlin(
|
|||||||
expert_id = expert_ids_ptr[block_id];
|
expert_id = expert_ids_ptr[block_id];
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (w_type == vllm::kFE2M1f) {
|
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||||
uint16_t val = scale2_ptr[expert_id];
|
uint16_t val = scale2_ptr[expert_id];
|
||||||
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
|
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
|
||||||
}
|
}
|
||||||
@ -526,6 +541,9 @@ __global__ void Marlin(
|
|||||||
if constexpr (has_act_order) {
|
if constexpr (has_act_order) {
|
||||||
g_idx += (expert_id - old_expert_id) * prob_k;
|
g_idx += (expert_id - old_expert_id) * prob_k;
|
||||||
}
|
}
|
||||||
|
if (has_bias) {
|
||||||
|
b_bias_ptr += (expert_id - old_expert_id) * b_bias_expert_stride;
|
||||||
|
}
|
||||||
|
|
||||||
read_moe_block_data(block_id);
|
read_moe_block_data(block_id);
|
||||||
};
|
};
|
||||||
@ -721,7 +739,7 @@ __global__ void Marlin(
|
|||||||
|
|
||||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||||
(threadIdx.x % 32) / 4;
|
(threadIdx.x % 32) / 4;
|
||||||
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
|
s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;
|
||||||
|
|
||||||
} else if constexpr (group_blocks != -1)
|
} else if constexpr (group_blocks != -1)
|
||||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||||
@ -734,6 +752,18 @@ __global__ void Marlin(
|
|||||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||||
(threadIdx.x % 32) % 4;
|
(threadIdx.x % 32) % 4;
|
||||||
|
|
||||||
|
int bias_sh_rd;
|
||||||
|
if constexpr (m_block_size_8) {
|
||||||
|
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||||
|
(threadIdx.x % 32) / 8;
|
||||||
|
} else {
|
||||||
|
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||||
|
(threadIdx.x % 32) % 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
int bias_sh_wr = threadIdx.x;
|
||||||
|
int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
|
||||||
|
|
||||||
// Zero-points have the same read layout as the scales
|
// Zero-points have the same read layout as the scales
|
||||||
// (without column-wise case)
|
// (without column-wise case)
|
||||||
constexpr int num_col_threads = 8;
|
constexpr int num_col_threads = 8;
|
||||||
@ -793,7 +823,19 @@ __global__ void Marlin(
|
|||||||
constexpr int sh_b_size = stages * b_sh_stage;
|
constexpr int sh_b_size = stages * b_sh_stage;
|
||||||
int4* sh_b = sh_new;
|
int4* sh_b = sh_new;
|
||||||
int4* sh_red = sh_new;
|
int4* sh_red = sh_new;
|
||||||
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
|
||||||
|
constexpr int sh_size_b_red_min =
|
||||||
|
(sh_red_size < sh_b_size ? sh_red_size : sh_b_size);
|
||||||
|
constexpr int sh_size_b_red_max =
|
||||||
|
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
||||||
|
constexpr int sh_bias_size = (thread_n_blocks * 16 / 8);
|
||||||
|
constexpr int sh_b_red_bias_size =
|
||||||
|
sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size)
|
||||||
|
? sh_size_b_red_max
|
||||||
|
: (sh_size_b_red_min + sh_bias_size);
|
||||||
|
|
||||||
|
int4* sh_bias = sh_new + sh_size_b_red_min;
|
||||||
|
int4* sh_g_idx = sh_new + sh_b_red_bias_size;
|
||||||
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
||||||
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||||
: (stages * s_sh_stage);
|
: (stages * s_sh_stage);
|
||||||
@ -803,9 +845,9 @@ __global__ void Marlin(
|
|||||||
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||||
stages * b_sh_stage);
|
stages * b_sh_stage);
|
||||||
int4* sh_a = sh_s + sh_s_size;
|
int4* sh_a = sh_s + sh_s_size;
|
||||||
constexpr int shm_size_used =
|
constexpr int shm_size_used = moe_block_size +
|
||||||
moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
|
stages * (g_idx_stage + zp_sh_stage) +
|
||||||
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
sh_s_size + sh_b_red_bias_size;
|
||||||
|
|
||||||
// all remaining shared memory is used to cache A (input)
|
// all remaining shared memory is used to cache A (input)
|
||||||
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
|
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
|
||||||
@ -816,7 +858,8 @@ __global__ void Marlin(
|
|||||||
FragA frag_a[2][thread_m_blocks];
|
FragA frag_a[2][thread_m_blocks];
|
||||||
I4 frag_b_quant[2][b_thread_vecs];
|
I4 frag_b_quant[2][b_thread_vecs];
|
||||||
FragC frag_c[thread_m_blocks][4][2];
|
FragC frag_c[thread_m_blocks][4][2];
|
||||||
FragS frag_s[2][4]; // No act-order
|
FragS frag_s[2][4]; // No act-order
|
||||||
|
FragS frag_bias[2][4];
|
||||||
FragS act_frag_s[2][4][4]; // For act-order
|
FragS act_frag_s[2][4][4]; // For act-order
|
||||||
int frag_qzp[2][num_ints_per_thread]; // Zero-points
|
int frag_qzp[2][num_ints_per_thread]; // Zero-points
|
||||||
FragZP frag_zp; // Zero-points in fp16
|
FragZP frag_zp; // Zero-points in fp16
|
||||||
@ -1065,10 +1108,15 @@ __global__ void Marlin(
|
|||||||
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
|
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
|
||||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||||
} else {
|
} else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {
|
||||||
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
||||||
reinterpret_cast<int2*>(
|
reinterpret_cast<int2*>(
|
||||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||||
|
} else {
|
||||||
|
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
||||||
|
reinterpret_cast<int2*>(
|
||||||
|
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) +
|
||||||
|
k % 2];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1281,9 +1329,9 @@ __global__ void Marlin(
|
|||||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
||||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
||||||
|
|
||||||
dequant_fp8_scales<scalar_t2>(s_quant_0,
|
dequant_fp8_scales<scalar_t2, s_type_id>(
|
||||||
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
|
s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
|
||||||
dequant_fp8_scales<scalar_t2>(
|
dequant_fp8_scales<scalar_t2, s_type_id>(
|
||||||
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
|
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1566,7 +1614,7 @@ __global__ void Marlin(
|
|||||||
// Write out the reduce final result in the correct layout. We only actually
|
// Write out the reduce final result in the correct layout. We only actually
|
||||||
// reshuffle matrix fragments in this step, the reduction above is performed
|
// reshuffle matrix fragments in this step, the reduction above is performed
|
||||||
// in fragment layout.
|
// in fragment layout.
|
||||||
auto write_result = [&]() {
|
auto write_result = [&](bool last) {
|
||||||
int c_gl_stride = prob_n / 8;
|
int c_gl_stride = prob_n / 8;
|
||||||
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
|
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
|
||||||
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
|
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
|
||||||
@ -1592,7 +1640,7 @@ __global__ void Marlin(
|
|||||||
|
|
||||||
// We first reorder in shared memory to guarantee the most efficient final
|
// We first reorder in shared memory to guarantee the most efficient final
|
||||||
// global write patterns
|
// global write patterns
|
||||||
auto write = [&](int idx, float c0, float c1, FragS& s) {
|
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
|
||||||
scalar_t2 res =
|
scalar_t2 res =
|
||||||
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
|
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
|
||||||
|
|
||||||
@ -1601,14 +1649,27 @@ __global__ void Marlin(
|
|||||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||||
w_type.size_bits() == 4 &&
|
w_type.size_bits() == 4 &&
|
||||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||||
res = __hmul2(res, s[0]);
|
scalar_t2 tmp_scale = s[0];
|
||||||
|
if constexpr (m_block_size_8) {
|
||||||
|
tmp_scale = Dtype::num2num2(
|
||||||
|
reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);
|
||||||
|
}
|
||||||
|
res = __hmul2(res, tmp_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (w_type == vllm::kFE2M1f) {
|
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||||
if (!mul_topk_weights) {
|
if (!mul_topk_weights) {
|
||||||
res = __hmul2(res, global_scale);
|
res = __hmul2(res, global_scale);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (has_bias && last) {
|
||||||
|
scalar_t2 tmp_bias = b_bias[0];
|
||||||
|
if constexpr (m_block_size_8) {
|
||||||
|
tmp_bias = Dtype::num2num2(
|
||||||
|
reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);
|
||||||
|
}
|
||||||
|
res = __hadd2(res, tmp_bias);
|
||||||
|
}
|
||||||
|
|
||||||
if constexpr (m_block_size_8) {
|
if constexpr (m_block_size_8) {
|
||||||
((scalar_t*)sh_red)[idx] = res.x;
|
((scalar_t*)sh_red)[idx] = res.x;
|
||||||
@ -1626,19 +1687,25 @@ __global__ void Marlin(
|
|||||||
if constexpr (m_block_size_8) {
|
if constexpr (m_block_size_8) {
|
||||||
int wr = c_sh_wr + 16 * j;
|
int wr = c_sh_wr + 16 * j;
|
||||||
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
|
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
|
||||||
frag_s[j / 2][2 * (j % 2) + 0]);
|
frag_s[j / 2][2 * (j % 2) + 0],
|
||||||
|
frag_bias[j / 2][2 * (j % 2) + 0]);
|
||||||
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
|
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
|
||||||
frag_s[j / 2][2 * (j % 2) + 1]);
|
frag_s[j / 2][2 * (j % 2) + 1],
|
||||||
|
frag_bias[j / 2][2 * (j % 2) + 1]);
|
||||||
} else {
|
} else {
|
||||||
int wr = c_sh_wr + 8 * j;
|
int wr = c_sh_wr + 8 * j;
|
||||||
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
||||||
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
|
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0],
|
||||||
|
frag_bias[j / 2][2 * (j % 2) + 0]);
|
||||||
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
|
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
|
||||||
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
|
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0],
|
||||||
|
frag_bias[j / 2][2 * (j % 2) + 0]);
|
||||||
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
|
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
|
||||||
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
|
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1],
|
||||||
|
frag_bias[j / 2][2 * (j % 2) + 1]);
|
||||||
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
|
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
|
||||||
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
|
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1],
|
||||||
|
frag_bias[j / 2][2 * (j % 2) + 1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c_sh_wr += 16 * (4 * c_sh_stride);
|
c_sh_wr += 16 * (4 * c_sh_stride);
|
||||||
@ -1805,6 +1872,14 @@ __global__ void Marlin(
|
|||||||
}
|
}
|
||||||
|
|
||||||
thread_block_reduce();
|
thread_block_reduce();
|
||||||
|
|
||||||
|
if (has_bias && last) {
|
||||||
|
__syncthreads();
|
||||||
|
cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd],
|
||||||
|
threadIdx.x < 16 * thread_n_blocks / 8);
|
||||||
|
cp_async_fence();
|
||||||
|
}
|
||||||
|
|
||||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||||
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
||||||
@ -1867,11 +1942,20 @@ __global__ void Marlin(
|
|||||||
}
|
}
|
||||||
barrier_release(&locks[locks_off], last);
|
barrier_release(&locks[locks_off], last);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (has_bias && last) {
|
||||||
|
cp_async_wait<0>();
|
||||||
|
__syncthreads();
|
||||||
|
reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];
|
||||||
|
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
if (use_atomic_add && slice_count > 1 && slice_idx != 0)
|
if (use_atomic_add && slice_count > 1 && slice_idx != 0)
|
||||||
wait_negative_and_add(&locks[locks_off]);
|
wait_negative_and_add(&locks[locks_off]);
|
||||||
if (last || use_atomic_add)
|
if (last || use_atomic_add)
|
||||||
// only the last block in a slice actually writes the result
|
// only the last block in a slice actually writes the result
|
||||||
write_result();
|
write_result(last);
|
||||||
int old_slice_row = slice_row;
|
int old_slice_row = slice_row;
|
||||||
slice_row = 0;
|
slice_row = 0;
|
||||||
slice_col_par++;
|
slice_col_par++;
|
||||||
@ -1904,6 +1988,7 @@ __global__ void Marlin(
|
|||||||
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
|
||||||
// Update slice k/n for scales loading
|
// Update slice k/n for scales loading
|
||||||
if constexpr (has_act_order) {
|
if constexpr (has_act_order) {
|
||||||
slice_k_start = tb_k * slice_row;
|
slice_k_start = tb_k * slice_row;
|
||||||
|
|||||||
@ -51,8 +51,9 @@ __global__ void permute_cols_kernel(
|
|||||||
} // namespace marlin
|
} // namespace marlin
|
||||||
|
|
||||||
torch::Tensor moe_wna16_marlin_gemm(
|
torch::Tensor moe_wna16_marlin_gemm(
|
||||||
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
||||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
torch::Tensor& b_q_weight,
|
||||||
|
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
|
||||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||||
@ -212,7 +213,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
|||||||
// Get B size
|
// Get B size
|
||||||
int tb_k = th_config.thread_k;
|
int tb_k = th_config.thread_k;
|
||||||
int tb_n = th_config.thread_n;
|
int tb_n = th_config.thread_n;
|
||||||
int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16);
|
int tb_m = thread_m_blocks * 16;
|
||||||
|
|
||||||
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
|
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
|
||||||
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
||||||
@ -220,6 +221,11 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
|||||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||||
|
int sh_bias_size = tb_n * 2;
|
||||||
|
int tmp_size =
|
||||||
|
(sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
|
||||||
|
tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);
|
||||||
|
|
||||||
int sh_s_size =
|
int sh_s_size =
|
||||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||||
group_size, has_act_order, is_k_full);
|
group_size, has_act_order, is_k_full);
|
||||||
@ -234,8 +240,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
|||||||
sh_zp_size = sh_s_size / 2;
|
sh_zp_size = sh_s_size / 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
|
int total_size = tmp_size + sh_a_size + sh_s_size + sh_zp_size +
|
||||||
sh_zp_size + sh_g_idx_size + sh_block_meta_size;
|
sh_g_idx_size + sh_block_meta_size;
|
||||||
|
|
||||||
return total_size;
|
return total_size;
|
||||||
}
|
}
|
||||||
@ -270,20 +276,25 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
|||||||
int cache_size = get_kernel_cache_size(
|
int cache_size = get_kernel_cache_size(
|
||||||
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
|
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||||
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
||||||
return cache_size <= max_shared_mem;
|
return cache_size + 512 <= max_shared_mem;
|
||||||
}
|
}
|
||||||
|
|
||||||
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||||
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
|
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
|
||||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||||
is_zp_float == IS_ZP_FLOAT) { \
|
is_zp_float == IS_ZP_FLOAT) { \
|
||||||
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
constexpr auto S_TYPE = \
|
||||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
|
W_TYPE == vllm::kFE2M1f \
|
||||||
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
|
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
|
||||||
|
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
|
||||||
|
: vllm::kBFloat16); \
|
||||||
|
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
|
||||||
|
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||||
|
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
|
||||||
}
|
}
|
||||||
|
|
||||||
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
|
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
|
||||||
@ -335,31 +346,45 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
|||||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||||
\
|
|
||||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||||
|
|
||||||
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
|
||||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
|
|
||||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
|
||||||
|
|
||||||
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
|
||||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
|
||||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
|
||||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
|
||||||
|
|
||||||
#define FP4_GET_IF(W_TYPE) \
|
|
||||||
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
|
||||||
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
|
||||||
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
|
||||||
FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
|
|
||||||
|
|
||||||
#define BIGGROUP_GET_IF(W_TYPE) \
|
#define BIGGROUP_GET_IF(W_TYPE) \
|
||||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||||
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||||
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
|
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||||
|
|
||||||
|
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
|
||||||
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||||
|
|
||||||
|
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
|
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||||
|
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||||
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||||
|
|
||||||
|
#define NVFP4_GET_IF(W_TYPE) \
|
||||||
|
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||||
|
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||||
|
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||||
|
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||||
|
|
||||||
|
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||||
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
|
||||||
|
|
||||||
|
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
|
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||||
|
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||||
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
|
||||||
|
|
||||||
|
#define MXFP4_GET_IF(W_TYPE) \
|
||||||
|
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||||
|
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||||
|
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||||
|
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||||
|
|
||||||
// We currently have 4-bit models only with group_blocks == 4
|
// We currently have 4-bit models only with group_blocks == 4
|
||||||
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
|
||||||
@ -408,12 +433,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
|||||||
COMMON_GET_IF(vllm::kU4B8)
|
COMMON_GET_IF(vllm::kU4B8)
|
||||||
COMMON_GET_IF(vllm::kU8B128)
|
COMMON_GET_IF(vllm::kU8B128)
|
||||||
|
|
||||||
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
NVFP4_GET_IF(vllm::kFE2M1f)
|
||||||
|
|
||||||
FP4_GET_IF(vllm::kFE2M1f)
|
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
||||||
|
|
||||||
ACT_GET_IF(vllm::kU4B8)
|
ACT_GET_IF(vllm::kU4B8)
|
||||||
ACT_GET_IF(vllm::kU8B128)
|
ACT_GET_IF(vllm::kU8B128)
|
||||||
|
if (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||||
|
if (false) {
|
||||||
|
}
|
||||||
|
MXFP4_GET_IF(vllm::kFE2M1f)
|
||||||
|
}
|
||||||
|
|
||||||
return kernel;
|
return kernel;
|
||||||
}
|
}
|
||||||
@ -482,16 +512,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||||
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
|
void* s, void* s2, void* zp, void* g_idx, void* perm,
|
||||||
void* sorted_token_ids, void* expert_ids,
|
void* a_tmp, void* sorted_token_ids, void* expert_ids,
|
||||||
void* num_tokens_past_padded, void* topk_weights,
|
void* num_tokens_past_padded, void* topk_weights,
|
||||||
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
|
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
|
||||||
int prob_m, int prob_n, int prob_k, void* workspace,
|
int prob_m, int prob_n, int prob_k, void* workspace,
|
||||||
vllm::ScalarType const& q_type, bool has_act_order,
|
vllm::ScalarType const& q_type, bool has_bias,
|
||||||
bool is_k_full, bool has_zp, int num_groups, int group_size,
|
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
|
||||||
int dev, cudaStream_t stream, int thread_k, int thread_n,
|
int group_size, int dev, cudaStream_t stream, int thread_k,
|
||||||
int sms, bool use_atomic_add, bool use_fp32_reduce,
|
int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce,
|
||||||
bool is_zp_float) {
|
bool is_zp_float) {
|
||||||
int thread_m_blocks = div_ceil(moe_block_size, 16);
|
int thread_m_blocks = div_ceil(moe_block_size, 16);
|
||||||
bool m_block_size_8 = moe_block_size == 8;
|
bool m_block_size_8 = moe_block_size == 8;
|
||||||
@ -538,6 +568,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
|||||||
const int4* B_ptr = (const int4*)B;
|
const int4* B_ptr = (const int4*)B;
|
||||||
int4* C_ptr = (int4*)C;
|
int4* C_ptr = (int4*)C;
|
||||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||||
|
const int4* bias_ptr = (const int4*)b_bias;
|
||||||
const int4* s_ptr = (const int4*)s;
|
const int4* s_ptr = (const int4*)s;
|
||||||
const uint16_t* s2_ptr = (const uint16_t*)s2;
|
const uint16_t* s2_ptr = (const uint16_t*)s2;
|
||||||
const int4* zp_ptr = (const int4*)zp;
|
const int4* zp_ptr = (const int4*)zp;
|
||||||
@ -648,10 +679,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
|||||||
// avoid ">>>" being formatted to "> > >"
|
// avoid ">>>" being formatted to "> > >"
|
||||||
// clang-format off
|
// clang-format off
|
||||||
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
|
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
|
||||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
|
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
|
||||||
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
|
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
|
||||||
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
|
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
|
||||||
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
|
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -659,7 +690,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
|||||||
|
|
||||||
torch::Tensor moe_wna16_marlin_gemm(
|
torch::Tensor moe_wna16_marlin_gemm(
|
||||||
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
||||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
torch::Tensor& b_q_weight,
|
||||||
|
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
|
||||||
std::optional<torch::Tensor> const& global_scale_or_none,
|
std::optional<torch::Tensor> const& global_scale_or_none,
|
||||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||||
@ -766,7 +798,6 @@ torch::Tensor moe_wna16_marlin_gemm(
|
|||||||
num_groups = b_scales.size(1);
|
num_groups = b_scales.size(1);
|
||||||
|
|
||||||
torch::Tensor g_idx, perm, a_tmp;
|
torch::Tensor g_idx, perm, a_tmp;
|
||||||
;
|
|
||||||
if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
|
if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
|
||||||
g_idx = g_idx_or_none.value();
|
g_idx = g_idx_or_none.value();
|
||||||
perm = perm_or_none.value();
|
perm = perm_or_none.value();
|
||||||
@ -815,12 +846,24 @@ torch::Tensor moe_wna16_marlin_gemm(
|
|||||||
torch::Tensor global_scale;
|
torch::Tensor global_scale;
|
||||||
if (global_scale_or_none.has_value()) {
|
if (global_scale_or_none.has_value()) {
|
||||||
global_scale = global_scale_or_none.value();
|
global_scale = global_scale_or_none.value();
|
||||||
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
|
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
|
||||||
"global_scale can only be used for float4_e2m1f.");
|
"global_scale can only be used for nvfp4 format.");
|
||||||
} else {
|
} else {
|
||||||
global_scale = torch::empty({0}, options);
|
global_scale = torch::empty({0}, options);
|
||||||
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
|
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
|
||||||
"the global_scale parameter must be passed for float4_e2m1f.");
|
"the global_scale parameter must be passed for nvfp4 format.");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool has_bias = b_bias_or_none.has_value();
|
||||||
|
torch::Tensor b_bias;
|
||||||
|
if (has_bias) {
|
||||||
|
b_bias = b_bias_or_none.value();
|
||||||
|
TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU");
|
||||||
|
TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous");
|
||||||
|
TORCH_CHECK(b_bias.size(1) == size_n, "b_bias.size(0) != size_n");
|
||||||
|
TORCH_CHECK(b_bias.stride(1) == 1, "b_bias.stride(1) != 1");
|
||||||
|
} else {
|
||||||
|
b_bias = torch::empty({0}, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor b_zeros;
|
torch::Tensor b_zeros;
|
||||||
@ -832,7 +875,6 @@ torch::Tensor moe_wna16_marlin_gemm(
|
|||||||
b_zeros = torch::empty({0}, options);
|
b_zeros = torch::empty({0}, options);
|
||||||
}
|
}
|
||||||
bool has_zp = b_zeros.size(-1) > 0;
|
bool has_zp = b_zeros.size(-1) > 0;
|
||||||
|
|
||||||
if (has_zp) {
|
if (has_zp) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
|
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
|
||||||
@ -890,41 +932,58 @@ torch::Tensor moe_wna16_marlin_gemm(
|
|||||||
if (a.scalar_type() == at::ScalarType::Half) {
|
if (a.scalar_type() == at::ScalarType::Half) {
|
||||||
void* scales_ptr;
|
void* scales_ptr;
|
||||||
if (b_q_type == vllm::kFE2M1f) {
|
if (b_q_type == vllm::kFE2M1f) {
|
||||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
if (group_size == 16)
|
||||||
|
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||||
|
else if (group_size == 32)
|
||||||
|
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||||
|
else
|
||||||
|
TORCH_CHECK(false,
|
||||||
|
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||||
|
"and group_size == 32 (MXFP4)");
|
||||||
} else {
|
} else {
|
||||||
scales_ptr = b_scales.data_ptr<at::Half>();
|
scales_ptr = b_scales.data_ptr<at::Half>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
|
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
|
||||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||||
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
|
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
|
||||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||||
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
|
perm.data_ptr(), a_tmp.data_ptr<at::Half>(),
|
||||||
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
|
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
|
||||||
topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep,
|
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
|
||||||
size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order,
|
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
|
||||||
is_k_full, has_zp, num_groups, group_size, dev,
|
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
|
||||||
|
has_zp, num_groups, group_size, dev,
|
||||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||||
void* scales_ptr;
|
void* scales_ptr;
|
||||||
if (b_q_type == vllm::kFE2M1f) {
|
if (b_q_type == vllm::kFE2M1f) {
|
||||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
if (group_size == 16)
|
||||||
|
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||||
|
else if (group_size == 32)
|
||||||
|
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||||
|
else
|
||||||
|
TORCH_CHECK(false,
|
||||||
|
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||||
|
"and group_size == 32 (MXFP4)");
|
||||||
} else {
|
} else {
|
||||||
scales_ptr = b_scales.data_ptr<at::BFloat16>();
|
scales_ptr = b_scales.data_ptr<at::BFloat16>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
|
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
|
||||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
|
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||||
|
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
|
||||||
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
|
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
|
||||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||||
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
|
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
|
||||||
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
|
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
|
||||||
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
|
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
|
||||||
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
|
||||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
has_zp, num_groups, group_size, dev,
|
||||||
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
|
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||||
|
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false,
|
TORCH_CHECK(false,
|
||||||
"moe_wna16_marlin_gemm only supports bfloat16 and float16");
|
"moe_wna16_marlin_gemm only supports bfloat16 and float16");
|
||||||
|
|||||||
@ -35,7 +35,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
|
"Tensor! b_q_weight, Tensor? b_bias_or_none,"
|
||||||
|
"Tensor! b_scales, Tensor? global_scale, Tensor? "
|
||||||
"b_zeros_or_none,"
|
"b_zeros_or_none,"
|
||||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||||
"Tensor sorted_token_ids,"
|
"Tensor sorted_token_ids,"
|
||||||
|
|||||||
16
csrc/ops.h
16
csrc/ops.h
@ -145,22 +145,6 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
|
|||||||
|
|
||||||
void gelu_quick(torch::Tensor& out, torch::Tensor& input);
|
void gelu_quick(torch::Tensor& out, torch::Tensor& input);
|
||||||
|
|
||||||
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
|
|
||||||
int64_t block_size, torch::Tensor& input_tokens,
|
|
||||||
torch::Tensor& sampled_token_ids,
|
|
||||||
torch::Tensor& input_positions,
|
|
||||||
torch::Tensor& seq_lens,
|
|
||||||
torch::Tensor& slot_mapping,
|
|
||||||
torch::Tensor& block_tables);
|
|
||||||
|
|
||||||
void advance_step_flashinfer(
|
|
||||||
int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
|
||||||
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
|
||||||
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
|
||||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
|
|
||||||
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
|
||||||
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
|
|
||||||
|
|
||||||
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||||
torch::Tensor const& q_pe,
|
torch::Tensor const& q_pe,
|
||||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||||
|
|||||||
@ -1,336 +0,0 @@
|
|||||||
/*
|
|
||||||
* The goal of this GPU kernel is to advance input tensors on the GPU directly
|
|
||||||
* PR: https://github.com/vllm-project/vllm/pull/6338
|
|
||||||
* Current restrictions:
|
|
||||||
* 1. Specialized for DraftModelRunner
|
|
||||||
* 2. Supports flash_attn only
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "advance_step.cuh"
|
|
||||||
|
|
||||||
namespace prepare_inputs {
|
|
||||||
|
|
||||||
//
|
|
||||||
template <int const num_threads>
|
|
||||||
__global__ void advance_step_flashattn_kernel(
|
|
||||||
int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
|
|
||||||
long const* sampled_token_ids_ptr, long* input_positions_ptr,
|
|
||||||
int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
|
|
||||||
int64_t const block_tables_stride) {
|
|
||||||
int const n_pad = num_seqs - num_queries;
|
|
||||||
if (n_pad && blockIdx.x == 0) {
|
|
||||||
// Handle cuda graph padding
|
|
||||||
int const offset = num_queries;
|
|
||||||
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
|
|
||||||
input_tokens_ptr[offset + i] = 0;
|
|
||||||
input_positions_ptr[offset + i] = 0;
|
|
||||||
slot_mapping_ptr[offset + i] = -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int num_query_blocks = div_ceil(num_queries, num_threads);
|
|
||||||
|
|
||||||
if (blockIdx.x >= num_query_blocks) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
|
|
||||||
|
|
||||||
if (cur_query_id >= num_queries) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update input_tokens
|
|
||||||
input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
|
|
||||||
|
|
||||||
int seq_len = seq_lens_ptr[cur_query_id];
|
|
||||||
int next_seq_len = seq_len + 1;
|
|
||||||
int next_input_pos = next_seq_len - 1;
|
|
||||||
|
|
||||||
// Update seq_lens
|
|
||||||
seq_lens_ptr[cur_query_id] = next_seq_len;
|
|
||||||
// Update input_positions
|
|
||||||
input_positions_ptr[cur_query_id] = next_input_pos;
|
|
||||||
|
|
||||||
int const* seq_block_tables_ptr =
|
|
||||||
block_tables_ptr + block_tables_stride * cur_query_id;
|
|
||||||
|
|
||||||
int block_index = next_input_pos / block_size;
|
|
||||||
int block_offset = next_input_pos % block_size;
|
|
||||||
|
|
||||||
int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
|
|
||||||
// Update slot_mapping
|
|
||||||
slot_mapping_ptr[cur_query_id] = slot_num;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void verify_tensor(std::string const& name, torch::Tensor const& t,
|
|
||||||
int64_t const size_0, int64_t const size_1,
|
|
||||||
c10::ScalarType const type) {
|
|
||||||
bool size_0_cond = true;
|
|
||||||
if (size_0 != -1) {
|
|
||||||
size_0_cond = t.size(0) == size_0;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool size_1_cond = true;
|
|
||||||
if (size_1 != -1) {
|
|
||||||
size_1_cond = t.size(1) == size_1;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_contiguous = t.is_contiguous();
|
|
||||||
bool same_type = t.dtype() == type;
|
|
||||||
|
|
||||||
bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
|
|
||||||
if (!pass) {
|
|
||||||
TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
|
|
||||||
" is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
|
|
||||||
" is not as expected: shape = [", size_0, ", ", size_1,
|
|
||||||
"], type = ", type);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// each thread processes a block per query
|
|
||||||
__global__ void advance_step_flashinfer_kernel(
|
|
||||||
int num_threads, int num_seqs, int num_queries, int block_size,
|
|
||||||
long* input_tokens_ptr, long const* sampled_token_ids_ptr,
|
|
||||||
long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
|
|
||||||
int const* block_tables_ptr, int64_t const block_tables_stride,
|
|
||||||
int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
|
|
||||||
int const n_pad = num_seqs - num_queries;
|
|
||||||
if (n_pad && blockIdx.x == 0) {
|
|
||||||
// Handle cuda graph padding
|
|
||||||
int const offset = num_queries;
|
|
||||||
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
|
|
||||||
input_tokens_ptr[offset + i] = 0;
|
|
||||||
input_positions_ptr[offset + i] = 0;
|
|
||||||
slot_mapping_ptr[offset + i] = -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
int num_query_blocks = div_ceil(num_queries, num_threads);
|
|
||||||
|
|
||||||
if (blockIdx.x < num_query_blocks) {
|
|
||||||
int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
|
|
||||||
|
|
||||||
if (cur_query_id < num_queries) {
|
|
||||||
// Update input_tokens
|
|
||||||
input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
|
|
||||||
|
|
||||||
int seq_len = seq_lens_ptr[cur_query_id];
|
|
||||||
int next_seq_len = seq_len + 1;
|
|
||||||
int next_input_pos = next_seq_len - 1;
|
|
||||||
|
|
||||||
// Update seq_lens
|
|
||||||
seq_lens_ptr[cur_query_id] = next_seq_len;
|
|
||||||
// Update input_positions
|
|
||||||
input_positions_ptr[cur_query_id] = next_input_pos;
|
|
||||||
|
|
||||||
int const* seq_block_tables_ptr =
|
|
||||||
block_tables_ptr + block_tables_stride * cur_query_id;
|
|
||||||
|
|
||||||
int block_index = next_input_pos / block_size;
|
|
||||||
int block_offset = next_input_pos % block_size;
|
|
||||||
|
|
||||||
// Update paged_kv_last_page_len
|
|
||||||
paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;
|
|
||||||
|
|
||||||
int slot_num =
|
|
||||||
seq_block_tables_ptr[block_index] * block_size + block_offset;
|
|
||||||
// Update slot_mapping
|
|
||||||
slot_mapping_ptr[cur_query_id] = slot_num;
|
|
||||||
block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
__global__ void advance_step_flashinfer_indptr_kernel(
|
|
||||||
int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
|
|
||||||
int* block_table_bound_ptr) {
|
|
||||||
int idx = blockIdx.x * num_threads + threadIdx.x;
|
|
||||||
// Update paged_kv_indptr
|
|
||||||
if (idx == 0) {
|
|
||||||
paged_kv_indptr_ptr[idx] = 0;
|
|
||||||
}
|
|
||||||
if (idx < num_queries) {
|
|
||||||
int sum = 0;
|
|
||||||
for (int i = 0; i <= idx; ++i) {
|
|
||||||
sum += block_table_bound_ptr[i];
|
|
||||||
}
|
|
||||||
paged_kv_indptr_ptr[idx + 1] = sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
__global__ void advance_step_flashinfer_indices_kernel(
|
|
||||||
int num_seqs, int num_queries, int const* block_tables_ptr,
|
|
||||||
int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr,
|
|
||||||
int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
|
|
||||||
// note: max_num_blocks_per_seq = block_tables.stride(0)
|
|
||||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
||||||
|
|
||||||
// when cuda graphs are enabled, paged_kv_indptr tensor
|
|
||||||
// has to be updated for the padded queries
|
|
||||||
// tid represents a query# for paged_kv_indptr tensor
|
|
||||||
if (num_queries < tid && tid <= num_seqs) {
|
|
||||||
paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries];
|
|
||||||
}
|
|
||||||
|
|
||||||
// each thread processes a block_ptr in block_tables
|
|
||||||
// block_tables shape: [num_queries, max_num_blocks_per_seq]
|
|
||||||
// paged_kv_indices is flattened block_tables.
|
|
||||||
for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq);
|
|
||||||
idx += (gridDim.x * blockDim.x)) {
|
|
||||||
// block_tables-row = paged_kv_indptr[queryNum]
|
|
||||||
int queryNum = idx / max_num_blocks_per_seq;
|
|
||||||
int col = idx % max_num_blocks_per_seq;
|
|
||||||
if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) {
|
|
||||||
int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col;
|
|
||||||
int block_tables_idx = queryNum * max_num_blocks_per_seq + col;
|
|
||||||
paged_kv_indices_ptr[indices_arr_idx] =
|
|
||||||
block_tables_ptr[block_tables_idx];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void advance_step_flashattn(int num_seqs, int num_queries, int block_size,
|
|
||||||
torch::Tensor& input_tokens, // type: long
|
|
||||||
torch::Tensor& sampled_token_ids, // type: long
|
|
||||||
torch::Tensor& input_positions, // type: long
|
|
||||||
torch::Tensor& seq_lens, // type: int
|
|
||||||
torch::Tensor& slot_mapping, // type: long
|
|
||||||
torch::Tensor& block_tables) { // type: int
|
|
||||||
|
|
||||||
if (logging) {
|
|
||||||
printf("advance_step_flashattn:\n");
|
|
||||||
printf(" num_seqs = %d\n", num_seqs);
|
|
||||||
printf(" num_queries = %d\n", num_queries);
|
|
||||||
printf(" block_size = %d\n", block_size);
|
|
||||||
}
|
|
||||||
// Verify all tensors
|
|
||||||
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
|
|
||||||
verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
|
|
||||||
at::kLong);
|
|
||||||
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
|
|
||||||
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
|
|
||||||
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
|
|
||||||
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
|
|
||||||
|
|
||||||
int dev = sampled_token_ids.get_device();
|
|
||||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
|
|
||||||
|
|
||||||
int blocks;
|
|
||||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
|
||||||
|
|
||||||
advance_step_flashattn_kernel<max_threads>
|
|
||||||
<<<blocks, max_threads, 0, stream>>>(
|
|
||||||
num_seqs, num_queries, block_size,
|
|
||||||
reinterpret_cast<long*>(input_tokens.data_ptr()),
|
|
||||||
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
|
|
||||||
reinterpret_cast<long*>(input_positions.data_ptr()),
|
|
||||||
reinterpret_cast<int*>(seq_lens.data_ptr()),
|
|
||||||
reinterpret_cast<long*>(slot_mapping.data_ptr()),
|
|
||||||
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
|
||||||
block_tables.stride(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
void advance_step_flashinfer(
|
|
||||||
int num_seqs, int num_queries, int block_size,
|
|
||||||
torch::Tensor& input_tokens, // type: long
|
|
||||||
torch::Tensor& sampled_token_ids, // type: long
|
|
||||||
torch::Tensor& input_positions, // type: long
|
|
||||||
torch::Tensor& seq_lens, // type: int
|
|
||||||
torch::Tensor& slot_mapping, // type: long
|
|
||||||
torch::Tensor& block_tables, // type: int
|
|
||||||
torch::Tensor& paged_kv_indices, // type: int
|
|
||||||
torch::Tensor& paged_kv_indptr, // type: int
|
|
||||||
torch::Tensor& paged_kv_last_page_len, // type: int
|
|
||||||
torch::Tensor& block_table_bound) { // type: int
|
|
||||||
|
|
||||||
if (logging) {
|
|
||||||
printf("advance_step_flashinfer:\n");
|
|
||||||
printf(" num_seqs = %d\n", num_seqs);
|
|
||||||
printf(" num_queries = %d\n", num_queries);
|
|
||||||
printf(" block_size = %d\n", block_size);
|
|
||||||
printf(" block_tables.stride(0) = %zu\n", block_tables.stride(0));
|
|
||||||
}
|
|
||||||
// Verify all tensors
|
|
||||||
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
|
|
||||||
// verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
|
|
||||||
// at::kLong);
|
|
||||||
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
|
|
||||||
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
|
|
||||||
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
|
|
||||||
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
|
|
||||||
|
|
||||||
verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
|
|
||||||
verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt);
|
|
||||||
verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
|
|
||||||
at::kInt);
|
|
||||||
|
|
||||||
verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);
|
|
||||||
|
|
||||||
int dev = sampled_token_ids.get_device();
|
|
||||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
|
|
||||||
|
|
||||||
int blocks;
|
|
||||||
int threads;
|
|
||||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
|
||||||
cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
|
|
||||||
|
|
||||||
TORCH_CHECK((blocks * threads > num_queries),
|
|
||||||
"multi-step: not enough threads to map to num_queries = ",
|
|
||||||
num_queries, " block_tables.stride(0) = ", block_tables.stride(0),
|
|
||||||
" blocks = ", blocks, " max_threads = ", threads);
|
|
||||||
if (logging) {
|
|
||||||
printf("launching kernels with %d blocks and %d threads\n", blocks,
|
|
||||||
threads);
|
|
||||||
}
|
|
||||||
advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
|
|
||||||
threads, num_seqs, num_queries, block_size,
|
|
||||||
reinterpret_cast<long*>(input_tokens.data_ptr()),
|
|
||||||
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
|
|
||||||
reinterpret_cast<long*>(input_positions.data_ptr()),
|
|
||||||
reinterpret_cast<int*>(seq_lens.data_ptr()),
|
|
||||||
reinterpret_cast<long*>(slot_mapping.data_ptr()),
|
|
||||||
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
|
||||||
block_tables.stride(0),
|
|
||||||
reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
|
|
||||||
reinterpret_cast<int*>(block_table_bound.data_ptr()));
|
|
||||||
|
|
||||||
advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
|
|
||||||
threads, num_seqs, num_queries,
|
|
||||||
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
|
|
||||||
reinterpret_cast<int*>(block_table_bound.data_ptr()));
|
|
||||||
|
|
||||||
advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
|
|
||||||
num_seqs, num_queries,
|
|
||||||
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
|
||||||
block_tables.stride(0),
|
|
||||||
reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
|
|
||||||
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
|
|
||||||
reinterpret_cast<int*>(block_table_bound.data_ptr()));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace prepare_inputs
|
|
||||||
|
|
||||||
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
|
|
||||||
int64_t block_size, torch::Tensor& input_tokens,
|
|
||||||
torch::Tensor& sampled_token_ids,
|
|
||||||
torch::Tensor& input_positions,
|
|
||||||
torch::Tensor& seq_lens,
|
|
||||||
torch::Tensor& slot_mapping,
|
|
||||||
torch::Tensor& block_tables) {
|
|
||||||
prepare_inputs::advance_step_flashattn(
|
|
||||||
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
|
|
||||||
input_positions, seq_lens, slot_mapping, block_tables);
|
|
||||||
}
|
|
||||||
|
|
||||||
void advance_step_flashinfer(
|
|
||||||
int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
|
||||||
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
|
||||||
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
|
||||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
|
|
||||||
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
|
||||||
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
|
|
||||||
prepare_inputs::advance_step_flashinfer(
|
|
||||||
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
|
|
||||||
input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
|
|
||||||
paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
|
|
||||||
}
|
|
||||||
@ -1,19 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <torch/all.h>
|
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
#include <cuda.h>
|
|
||||||
#include <cuda_fp16.h>
|
|
||||||
#include <cuda_runtime.h>
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
namespace prepare_inputs {
|
|
||||||
|
|
||||||
static constexpr int max_threads = 256;
|
|
||||||
static constexpr bool logging = false;
|
|
||||||
|
|
||||||
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
|
||||||
|
|
||||||
} // namespace prepare_inputs
|
|
||||||
@ -470,11 +470,12 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
|
|||||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t2>
|
template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
|
||||||
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
|
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
|
__device__ inline void dequant_fp8_scales<half2, vllm::kFE4M3fn.id()>(
|
||||||
|
int q, half2* frag_b) {
|
||||||
int Out1 = (q & 0xFF00FF00) >> 1;
|
int Out1 = (q & 0xFF00FF00) >> 1;
|
||||||
;
|
;
|
||||||
q <<= 8;
|
q <<= 8;
|
||||||
@ -486,8 +487,8 @@ __device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
|
__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE4M3fn.id()>(
|
||||||
nv_bfloat162* frag_b) {
|
int q, nv_bfloat162* frag_b) {
|
||||||
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
||||||
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
|
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
|
||||||
constexpr int MASK = 0x7F007F00;
|
constexpr int MASK = 0x7F007F00;
|
||||||
@ -502,6 +503,20 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
|
|||||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
|
||||||
|
int q, nv_bfloat162* frag_b) {
|
||||||
|
// In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16,
|
||||||
|
// but we assume that such a extreme value would not occur in real models.
|
||||||
|
int Out1 = (q & 0xFF00FF00) >> 1;
|
||||||
|
q <<= 7;
|
||||||
|
int Out2 = q & 0x7F807F80;
|
||||||
|
|
||||||
|
// Note: reverse indexing is intentional because weights are permuted
|
||||||
|
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
||||||
|
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||||
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
} // namespace MARLIN_NAMESPACE_NAME
|
} // namespace MARLIN_NAMESPACE_NAME
|
||||||
|
|||||||
@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME {
|
|||||||
TEMPLATE = ("template __global__ void Marlin<"
|
TEMPLATE = ("template __global__ void Marlin<"
|
||||||
"{{scalar_t}}, "
|
"{{scalar_t}}, "
|
||||||
"{{w_type_id}}, "
|
"{{w_type_id}}, "
|
||||||
|
"{{s_type_id}}, "
|
||||||
"{{threads}}, "
|
"{{threads}}, "
|
||||||
"{{thread_m_blocks}}, "
|
"{{thread_m_blocks}}, "
|
||||||
"{{thread_n_blocks}}, "
|
"{{thread_n_blocks}}, "
|
||||||
@ -78,7 +79,8 @@ def generate_new_kernels():
|
|||||||
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
||||||
continue
|
continue
|
||||||
# nvfp4 only supports group_size == 16
|
# nvfp4 only supports group_size == 16
|
||||||
if scalar_type == "vllm::kFE2M1f" and group_blocks != 1:
|
# mxfp4 only supports group_size == 32
|
||||||
|
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
|
||||||
continue
|
continue
|
||||||
# other quantization methods don't support group_size = 16
|
# other quantization methods don't support group_size = 16
|
||||||
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
|
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
|
||||||
@ -97,10 +99,23 @@ def generate_new_kernels():
|
|||||||
# 4bit quantization and fp16
|
# 4bit quantization and fp16
|
||||||
is_zp_float_list.append(True)
|
is_zp_float_list.append(True)
|
||||||
|
|
||||||
|
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
|
||||||
|
s_type = "vllm::kFE4M3fn"
|
||||||
|
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
|
||||||
|
s_type = "vllm::kFE8M0fnu"
|
||||||
|
if dtype == "fp16":
|
||||||
|
# we cannot safely dequantize e8m0 to fp16, so skip this
|
||||||
|
continue
|
||||||
|
elif dtype == "fp16":
|
||||||
|
s_type = "vllm::kFloat16"
|
||||||
|
elif dtype == "bf16":
|
||||||
|
s_type = "vllm::kBFloat16"
|
||||||
|
|
||||||
for is_zp_float in is_zp_float_list:
|
for is_zp_float in is_zp_float_list:
|
||||||
template_str = jinja2.Template(TEMPLATE).render(
|
template_str = jinja2.Template(TEMPLATE).render(
|
||||||
scalar_t=c_dtype,
|
scalar_t=c_dtype,
|
||||||
w_type_id=scalar_type + ".id()",
|
w_type_id=scalar_type + ".id()",
|
||||||
|
s_type_id=s_type + ".id()",
|
||||||
threads=threads,
|
threads=threads,
|
||||||
thread_m_blocks=max(m_blocks, 1),
|
thread_m_blocks=max(m_blocks, 1),
|
||||||
thread_n_blocks=n_blocks,
|
thread_n_blocks=n_blocks,
|
||||||
|
|||||||
@ -48,7 +48,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
|||||||
|
|
||||||
torch::Tensor gptq_marlin_gemm(
|
torch::Tensor gptq_marlin_gemm(
|
||||||
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
||||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
torch::Tensor& b_q_weight,
|
||||||
|
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
|
||||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||||
@ -187,7 +188,12 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
|||||||
int tb_m = thread_m_blocks * 16;
|
int tb_m = thread_m_blocks * 16;
|
||||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||||
int sh_red_size = tb_m * (tb_n + 8);
|
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||||
|
int sh_bias_size = tb_n * 2;
|
||||||
|
int tmp_size =
|
||||||
|
(sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
|
||||||
|
tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);
|
||||||
|
|
||||||
int sh_s_size =
|
int sh_s_size =
|
||||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||||
group_size, has_act_order, is_k_full);
|
group_size, has_act_order, is_k_full);
|
||||||
@ -202,8 +208,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
|||||||
sh_zp_size = sh_s_size / 2;
|
sh_zp_size = sh_s_size / 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
|
int total_size =
|
||||||
sh_zp_size + sh_g_idx_size;
|
tmp_size + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size;
|
||||||
|
|
||||||
return total_size;
|
return total_size;
|
||||||
}
|
}
|
||||||
@ -237,20 +243,25 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
|||||||
int cache_size = get_kernel_cache_size(
|
int cache_size = get_kernel_cache_size(
|
||||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
||||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
has_act_order, is_k_full, has_zp, is_zp_float);
|
||||||
return cache_size <= max_shared_mem;
|
return cache_size + 512 <= max_shared_mem;
|
||||||
}
|
}
|
||||||
|
|
||||||
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||||
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
|
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
|
||||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||||
is_zp_float == IS_ZP_FLOAT) { \
|
is_zp_float == IS_ZP_FLOAT) { \
|
||||||
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
constexpr auto S_TYPE = \
|
||||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
|
W_TYPE == vllm::kFE2M1f \
|
||||||
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
|
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
|
||||||
|
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
|
||||||
|
: vllm::kBFloat16); \
|
||||||
|
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
|
||||||
|
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||||
|
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
|
||||||
}
|
}
|
||||||
|
|
||||||
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
|
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
|
||||||
@ -315,22 +326,39 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
|||||||
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||||
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
|
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||||
|
|
||||||
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
|
||||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||||
|
|
||||||
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||||
|
|
||||||
#define FP4_GET_IF(W_TYPE) \
|
#define NVFP4_GET_IF(W_TYPE) \
|
||||||
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||||
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||||
FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||||
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||||
FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||||
FP4_GET_IF_M234(W_TYPE, 4, 8, 128)
|
NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||||
|
|
||||||
|
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||||
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
|
||||||
|
|
||||||
|
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
|
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||||
|
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||||
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
|
||||||
|
|
||||||
|
#define MXFP4_GET_IF(W_TYPE) \
|
||||||
|
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||||
|
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||||
|
MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||||
|
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||||
|
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||||
|
MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||||
|
|
||||||
// We currently have 4-bit models only with group_blocks == 4
|
// We currently have 4-bit models only with group_blocks == 4
|
||||||
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
@ -384,7 +412,7 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
|||||||
COMMON_GET_IF(vllm::kU4B8)
|
COMMON_GET_IF(vllm::kU4B8)
|
||||||
COMMON_GET_IF(vllm::kU8B128)
|
COMMON_GET_IF(vllm::kU8B128)
|
||||||
|
|
||||||
FP4_GET_IF(vllm::kFE2M1f)
|
NVFP4_GET_IF(vllm::kFE2M1f)
|
||||||
|
|
||||||
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
||||||
|
|
||||||
@ -396,6 +424,11 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
|||||||
}
|
}
|
||||||
FZP_GET_IF(vllm::kU4)
|
FZP_GET_IF(vllm::kU4)
|
||||||
}
|
}
|
||||||
|
if (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||||
|
if (false) {
|
||||||
|
}
|
||||||
|
MXFP4_GET_IF(vllm::kFE2M1f)
|
||||||
|
}
|
||||||
|
|
||||||
return kernel;
|
return kernel;
|
||||||
}
|
}
|
||||||
@ -453,12 +486,12 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||||
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
|
void* s, void* s2, void* zp, void* g_idx, void* perm,
|
||||||
int prob_m, int prob_n, int prob_k, int lda, void* workspace,
|
void* a_tmp, int prob_m, int prob_n, int prob_k, int lda,
|
||||||
vllm::ScalarType const& q_type, bool has_act_order,
|
void* workspace, vllm::ScalarType const& q_type, bool has_bias,
|
||||||
bool is_k_full, bool has_zp, int num_groups, int group_size,
|
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
|
||||||
int dev, cudaStream_t stream, int thread_k_init,
|
int group_size, int dev, cudaStream_t stream, int thread_k_init,
|
||||||
int thread_n_init, int sms, bool use_atomic_add,
|
int thread_n_init, int sms, bool use_atomic_add,
|
||||||
bool use_fp32_reduce, bool is_zp_float) {
|
bool use_fp32_reduce, bool is_zp_float) {
|
||||||
if (has_zp) {
|
if (has_zp) {
|
||||||
@ -503,6 +536,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
|||||||
const int4* B_ptr = (const int4*)B;
|
const int4* B_ptr = (const int4*)B;
|
||||||
int4* C_ptr = (int4*)C;
|
int4* C_ptr = (int4*)C;
|
||||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||||
|
const int4* bias_ptr = (const int4*)b_bias;
|
||||||
const int4* s_ptr = (const int4*)s;
|
const int4* s_ptr = (const int4*)s;
|
||||||
const uint16_t* s2_ptr = (const uint16_t*)s2;
|
const uint16_t* s2_ptr = (const uint16_t*)s2;
|
||||||
const int4* zp_ptr = (const int4*)zp;
|
const int4* zp_ptr = (const int4*)zp;
|
||||||
@ -623,8 +657,9 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
|||||||
// avoid ">>>" being formatted to "> > >"
|
// avoid ">>>" being formatted to "> > >"
|
||||||
// clang-format off
|
// clang-format off
|
||||||
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
|
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
|
||||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups,
|
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr,
|
||||||
prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add,
|
g_idx_ptr, num_groups,
|
||||||
|
prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add,
|
||||||
use_fp32_reduce, max_shared_mem_new);
|
use_fp32_reduce, max_shared_mem_new);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
@ -638,7 +673,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
|||||||
|
|
||||||
torch::Tensor gptq_marlin_gemm(
|
torch::Tensor gptq_marlin_gemm(
|
||||||
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
||||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
torch::Tensor& b_q_weight,
|
||||||
|
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
|
||||||
std::optional<torch::Tensor> const& global_scale_or_none,
|
std::optional<torch::Tensor> const& global_scale_or_none,
|
||||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||||
@ -785,12 +821,24 @@ torch::Tensor gptq_marlin_gemm(
|
|||||||
torch::Tensor global_scale;
|
torch::Tensor global_scale;
|
||||||
if (global_scale_or_none.has_value()) {
|
if (global_scale_or_none.has_value()) {
|
||||||
global_scale = global_scale_or_none.value();
|
global_scale = global_scale_or_none.value();
|
||||||
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
|
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
|
||||||
"global_scale can only be used for float4_e2m1f.");
|
"global_scale can only be used for nvfp4 format.");
|
||||||
} else {
|
} else {
|
||||||
global_scale = torch::empty({0}, options);
|
global_scale = torch::empty({0}, options);
|
||||||
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
|
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
|
||||||
"the global_scale parameter must be passed for float4_e2m1f.");
|
"the global_scale parameter must be passed for nvfp4 format.");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool has_bias = b_bias_or_none.has_value();
|
||||||
|
torch::Tensor b_bias;
|
||||||
|
if (has_bias) {
|
||||||
|
b_bias = b_bias_or_none.value();
|
||||||
|
TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU");
|
||||||
|
TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous");
|
||||||
|
TORCH_CHECK(b_bias.size(0) == size_n, "b_bias.size(0) != size_n");
|
||||||
|
TORCH_CHECK(b_bias.stride(0) == 1, "b_bias.stride(0) != 1");
|
||||||
|
} else {
|
||||||
|
b_bias = torch::empty({0}, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor b_zeros;
|
torch::Tensor b_zeros;
|
||||||
@ -857,34 +905,50 @@ torch::Tensor gptq_marlin_gemm(
|
|||||||
if (a.scalar_type() == at::ScalarType::Half) {
|
if (a.scalar_type() == at::ScalarType::Half) {
|
||||||
void* scales_ptr;
|
void* scales_ptr;
|
||||||
if (b_q_type == vllm::kFE2M1f) {
|
if (b_q_type == vllm::kFE2M1f) {
|
||||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
if (group_size == 16)
|
||||||
|
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||||
|
else if (group_size == 32)
|
||||||
|
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||||
|
else
|
||||||
|
TORCH_CHECK(false,
|
||||||
|
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||||
|
"and group_size == 32 (MXFP4)");
|
||||||
} else {
|
} else {
|
||||||
scales_ptr = b_scales.data_ptr<at::Half>();
|
scales_ptr = b_scales.data_ptr<at::Half>();
|
||||||
}
|
}
|
||||||
|
|
||||||
marlin::marlin_mm<half>(
|
marlin::marlin_mm<half>(
|
||||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||||
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
|
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
|
||||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||||
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k, a.stride(0),
|
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
|
||||||
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order,
|
||||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
is_k_full, has_zp, num_groups, group_size, dev,
|
||||||
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
|
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||||
|
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||||
void* scales_ptr;
|
void* scales_ptr;
|
||||||
if (b_q_type == vllm::kFE2M1f) {
|
if (b_q_type == vllm::kFE2M1f) {
|
||||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
if (group_size == 16)
|
||||||
|
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||||
|
else if (group_size == 32)
|
||||||
|
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||||
|
else
|
||||||
|
TORCH_CHECK(false,
|
||||||
|
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||||
|
"and group_size == 32 (MXFP4)");
|
||||||
} else {
|
} else {
|
||||||
scales_ptr = b_scales.data_ptr<at::BFloat16>();
|
scales_ptr = b_scales.data_ptr<at::BFloat16>();
|
||||||
}
|
}
|
||||||
|
|
||||||
marlin::marlin_mm<nv_bfloat16>(
|
marlin::marlin_mm<nv_bfloat16>(
|
||||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
|
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||||
|
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
|
||||||
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
|
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
|
||||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||||
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type,
|
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type,
|
||||||
has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
|
has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
|
||||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -10,15 +10,18 @@
|
|||||||
#define MARLIN_KERNEL_PARAMS \
|
#define MARLIN_KERNEL_PARAMS \
|
||||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||||
|
const int4 *__restrict__ b_bias_ptr, \
|
||||||
const int4 *__restrict__ scales_ptr, \
|
const int4 *__restrict__ scales_ptr, \
|
||||||
const uint16_t *__restrict__ scale2_ptr, \
|
const uint16_t *__restrict__ scale2_ptr, \
|
||||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||||
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
|
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
|
||||||
bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem
|
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
|
||||||
|
int max_shared_mem
|
||||||
|
|
||||||
namespace MARLIN_NAMESPACE_NAME {
|
namespace MARLIN_NAMESPACE_NAME {
|
||||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||||
|
const vllm::ScalarTypeId s_type_id, // weight ScalarType id
|
||||||
const int threads, // number of threads in a threadblock
|
const int threads, // number of threads in a threadblock
|
||||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||||
// dimension (batchsize) of the
|
// dimension (batchsize) of the
|
||||||
|
|||||||
@ -39,6 +39,7 @@ namespace MARLIN_NAMESPACE_NAME {
|
|||||||
|
|
||||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||||
|
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
|
||||||
const int threads, // number of threads in a threadblock
|
const int threads, // number of threads in a threadblock
|
||||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||||
// dimension (batchsize) of the
|
// dimension (batchsize) of the
|
||||||
@ -271,6 +272,7 @@ __device__ inline void wait_negative_and_add(int* lock) {
|
|||||||
|
|
||||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||||
|
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
|
||||||
const int threads, // number of threads in a threadblock
|
const int threads, // number of threads in a threadblock
|
||||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||||
// dimension (batchsize) of the
|
// dimension (batchsize) of the
|
||||||
@ -290,6 +292,7 @@ __global__ void Marlin(
|
|||||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||||
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
||||||
|
const int4* __restrict__ b_bias_ptr,
|
||||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||||
// (k/groupsize)xn
|
// (k/groupsize)xn
|
||||||
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
|
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
|
||||||
@ -297,12 +300,13 @@ __global__ void Marlin(
|
|||||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||||
// (k/groupsize)x(n/pack_factor)
|
// (k/groupsize)x(n/pack_factor)
|
||||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||||
int num_groups, // number of scale groups per output channel
|
int num_groups, // number of scale groups per output channel
|
||||||
int prob_m, // batch dimension m
|
int prob_m, // batch dimension m
|
||||||
int prob_n, // output dimension n
|
int prob_n, // output dimension n
|
||||||
int prob_k, // reduction dimension k
|
int prob_k, // reduction dimension k
|
||||||
int lda, // A.stride(0), equal to prob_k is A is contiguous
|
int lda, // A.stride(0), equal to prob_k is A is contiguous
|
||||||
int* locks, // extra global storage for barrier synchronization
|
int* locks, // extra global storage for barrier synchronization
|
||||||
|
bool has_bias,
|
||||||
bool use_atomic_add, // whether to use atomic add to reduce
|
bool use_atomic_add, // whether to use atomic add to reduce
|
||||||
bool use_fp32_reduce, // whether to use fp32 global reduce
|
bool use_fp32_reduce, // whether to use fp32 global reduce
|
||||||
int max_shared_mem) {
|
int max_shared_mem) {
|
||||||
@ -326,18 +330,29 @@ __global__ void Marlin(
|
|||||||
using FragZP = typename ScalarType<scalar_t>::FragZP;
|
using FragZP = typename ScalarType<scalar_t>::FragZP;
|
||||||
|
|
||||||
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
||||||
|
static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
|
||||||
|
if constexpr (w_type == vllm::kFE2M1f) {
|
||||||
|
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
|
||||||
|
s_type == vllm::kFE8M0fnu && group_blocks == 2);
|
||||||
|
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||||
|
static_assert(s_type == vllm::kBFloat16);
|
||||||
|
} else if constexpr (std::is_same<scalar_t, half>::value) {
|
||||||
|
static_assert(s_type == vllm::kFloat16);
|
||||||
|
}
|
||||||
|
|
||||||
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
|
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
|
||||||
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
|
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
|
||||||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
|
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
|
||||||
// see comments of dequant.h for more details
|
// see comments of dequant.h for more details
|
||||||
constexpr bool dequant_skip_flop =
|
constexpr bool dequant_skip_flop =
|
||||||
!is_int_type ||
|
w_type == vllm::kFE4M3fn ||
|
||||||
|
w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
|
||||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
|
has_zp && !is_zp_float && !(w_type == vllm::kU8);
|
||||||
|
|
||||||
scalar_t2 global_scale;
|
scalar_t2 global_scale;
|
||||||
|
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||||
if constexpr (w_type == vllm::kFE2M1f) {
|
// NVFP4 format requires global scale
|
||||||
uint16_t val = scale2_ptr[0];
|
uint16_t val = scale2_ptr[0];
|
||||||
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
|
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
|
||||||
}
|
}
|
||||||
@ -589,7 +604,7 @@ __global__ void Marlin(
|
|||||||
|
|
||||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||||
(threadIdx.x % 32) / 4;
|
(threadIdx.x % 32) / 4;
|
||||||
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
|
s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;
|
||||||
|
|
||||||
} else if constexpr (group_blocks != -1)
|
} else if constexpr (group_blocks != -1)
|
||||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||||
@ -602,6 +617,18 @@ __global__ void Marlin(
|
|||||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||||
(threadIdx.x % 32) % 4;
|
(threadIdx.x % 32) % 4;
|
||||||
|
|
||||||
|
int bias_sh_rd;
|
||||||
|
if constexpr (m_block_size_8) {
|
||||||
|
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||||
|
(threadIdx.x % 32) / 8;
|
||||||
|
} else {
|
||||||
|
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||||
|
(threadIdx.x % 32) % 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
int bias_sh_wr = threadIdx.x;
|
||||||
|
int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
|
||||||
|
|
||||||
// Zero-points have the same read layout as the scales
|
// Zero-points have the same read layout as the scales
|
||||||
// (without column-wise case)
|
// (without column-wise case)
|
||||||
constexpr int num_col_threads = 8;
|
constexpr int num_col_threads = 8;
|
||||||
@ -670,7 +697,19 @@ __global__ void Marlin(
|
|||||||
constexpr int sh_b_size = stages * b_sh_stage;
|
constexpr int sh_b_size = stages * b_sh_stage;
|
||||||
int4* sh_b = sh;
|
int4* sh_b = sh;
|
||||||
int4* sh_red = sh;
|
int4* sh_red = sh;
|
||||||
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
|
||||||
|
constexpr int sh_size_b_red_min =
|
||||||
|
(sh_red_size < sh_b_size ? sh_red_size : sh_b_size);
|
||||||
|
constexpr int sh_size_b_red_max =
|
||||||
|
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
||||||
|
constexpr int sh_bias_size = (thread_n_blocks * 16 / 8);
|
||||||
|
constexpr int sh_b_red_bias_size =
|
||||||
|
sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size)
|
||||||
|
? sh_size_b_red_max
|
||||||
|
: (sh_size_b_red_min + sh_bias_size);
|
||||||
|
|
||||||
|
int4* sh_bias = sh + sh_size_b_red_min;
|
||||||
|
int4* sh_g_idx = sh + sh_b_red_bias_size;
|
||||||
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
||||||
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||||
: (stages * s_sh_stage);
|
: (stages * s_sh_stage);
|
||||||
@ -680,15 +719,13 @@ __global__ void Marlin(
|
|||||||
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||||
stages * b_sh_stage);
|
stages * b_sh_stage);
|
||||||
int4* sh_a = sh_s + sh_s_size;
|
int4* sh_a = sh_s + sh_s_size;
|
||||||
// constexpr int shm_size_used =
|
|
||||||
// stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
|
|
||||||
// (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
|
||||||
|
|
||||||
// Register storage for double buffer of shared memory reads.
|
// Register storage for double buffer of shared memory reads.
|
||||||
FragA frag_a[2][thread_m_blocks];
|
FragA frag_a[2][thread_m_blocks];
|
||||||
I4 frag_b_quant[2][b_thread_vecs];
|
I4 frag_b_quant[2][b_thread_vecs];
|
||||||
FragC frag_c[thread_m_blocks][4][2];
|
FragC frag_c[thread_m_blocks][4][2];
|
||||||
FragS frag_s[2][4]; // No act-order
|
FragS frag_s[2][4]; // No act-order
|
||||||
|
FragS frag_bias[2][4];
|
||||||
FragS act_frag_s[2][4][4]; // For act-order
|
FragS act_frag_s[2][4][4]; // For act-order
|
||||||
int frag_qzp[2][num_ints_per_thread]; // Zero-points
|
int frag_qzp[2][num_ints_per_thread]; // Zero-points
|
||||||
FragZP frag_zp; // Zero-points in fp16
|
FragZP frag_zp; // Zero-points in fp16
|
||||||
@ -923,10 +960,15 @@ __global__ void Marlin(
|
|||||||
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
|
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
|
||||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||||
} else {
|
} else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {
|
||||||
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
||||||
reinterpret_cast<int2*>(
|
reinterpret_cast<int2*>(
|
||||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||||
|
} else {
|
||||||
|
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
||||||
|
reinterpret_cast<int2*>(
|
||||||
|
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) +
|
||||||
|
k % 2];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1139,9 +1181,9 @@ __global__ void Marlin(
|
|||||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
||||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
||||||
|
|
||||||
dequant_fp8_scales<scalar_t2>(s_quant_0,
|
dequant_fp8_scales<scalar_t2, s_type_id>(
|
||||||
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
|
s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
|
||||||
dequant_fp8_scales<scalar_t2>(
|
dequant_fp8_scales<scalar_t2, s_type_id>(
|
||||||
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
|
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1411,7 +1453,7 @@ __global__ void Marlin(
|
|||||||
// Write out the reduce final result in the correct layout. We only actually
|
// Write out the reduce final result in the correct layout. We only actually
|
||||||
// reshuffle matrix fragments in this step, the reduction above is performed
|
// reshuffle matrix fragments in this step, the reduction above is performed
|
||||||
// in fragment layout.
|
// in fragment layout.
|
||||||
auto write_result = [&]() {
|
auto write_result = [&](bool last) {
|
||||||
int c_gl_stride = prob_n / 8;
|
int c_gl_stride = prob_n / 8;
|
||||||
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
|
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
|
||||||
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
|
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
|
||||||
@ -1438,7 +1480,7 @@ __global__ void Marlin(
|
|||||||
int c_gl_wr_end = c_gl_stride * prob_m;
|
int c_gl_wr_end = c_gl_stride * prob_m;
|
||||||
// We first reorder in shared memory to guarantee the most efficient final
|
// We first reorder in shared memory to guarantee the most efficient final
|
||||||
// global write patterns
|
// global write patterns
|
||||||
auto write = [&](int idx, float c0, float c1, FragS& s) {
|
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
|
||||||
scalar_t2 res =
|
scalar_t2 res =
|
||||||
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
|
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
|
||||||
|
|
||||||
@ -1447,12 +1489,25 @@ __global__ void Marlin(
|
|||||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||||
w_type.size_bits() == 4 &&
|
w_type.size_bits() == 4 &&
|
||||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||||
res = __hmul2(res, s[0]);
|
scalar_t2 tmp_scale = s[0];
|
||||||
|
if constexpr (m_block_size_8) {
|
||||||
|
tmp_scale = Dtype::num2num2(
|
||||||
|
reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);
|
||||||
|
}
|
||||||
|
res = __hmul2(res, tmp_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (w_type == vllm::kFE2M1f) {
|
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||||
res = __hmul2(res, global_scale);
|
res = __hmul2(res, global_scale);
|
||||||
}
|
}
|
||||||
|
if (has_bias && last) {
|
||||||
|
scalar_t2 tmp_bias = b_bias[0];
|
||||||
|
if constexpr (m_block_size_8) {
|
||||||
|
tmp_bias = Dtype::num2num2(
|
||||||
|
reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);
|
||||||
|
}
|
||||||
|
res = __hadd2(res, tmp_bias);
|
||||||
|
}
|
||||||
|
|
||||||
if constexpr (m_block_size_8) {
|
if constexpr (m_block_size_8) {
|
||||||
((scalar_t*)sh_red)[idx] = res.x;
|
((scalar_t*)sh_red)[idx] = res.x;
|
||||||
@ -1470,19 +1525,25 @@ __global__ void Marlin(
|
|||||||
if constexpr (m_block_size_8) {
|
if constexpr (m_block_size_8) {
|
||||||
int wr = c_sh_wr + 16 * j;
|
int wr = c_sh_wr + 16 * j;
|
||||||
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
|
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
|
||||||
frag_s[j / 2][2 * (j % 2) + 0]);
|
frag_s[j / 2][2 * (j % 2) + 0],
|
||||||
|
frag_bias[j / 2][2 * (j % 2) + 0]);
|
||||||
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
|
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
|
||||||
frag_s[j / 2][2 * (j % 2) + 1]);
|
frag_s[j / 2][2 * (j % 2) + 1],
|
||||||
|
frag_bias[j / 2][2 * (j % 2) + 1]);
|
||||||
} else {
|
} else {
|
||||||
int wr = c_sh_wr + 8 * j;
|
int wr = c_sh_wr + 8 * j;
|
||||||
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
||||||
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
|
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0],
|
||||||
|
frag_bias[j / 2][2 * (j % 2) + 0]);
|
||||||
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
|
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
|
||||||
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
|
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0],
|
||||||
|
frag_bias[j / 2][2 * (j % 2) + 0]);
|
||||||
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
|
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
|
||||||
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
|
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1],
|
||||||
|
frag_bias[j / 2][2 * (j % 2) + 1]);
|
||||||
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
|
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
|
||||||
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
|
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1],
|
||||||
|
frag_bias[j / 2][2 * (j % 2) + 1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c_sh_wr += 16 * (4 * c_sh_stride);
|
c_sh_wr += 16 * (4 * c_sh_stride);
|
||||||
@ -1622,6 +1683,14 @@ __global__ void Marlin(
|
|||||||
}
|
}
|
||||||
|
|
||||||
thread_block_reduce();
|
thread_block_reduce();
|
||||||
|
|
||||||
|
if (has_bias && last) {
|
||||||
|
__syncthreads();
|
||||||
|
cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd],
|
||||||
|
threadIdx.x < 16 * thread_n_blocks / 8);
|
||||||
|
cp_async_fence();
|
||||||
|
}
|
||||||
|
|
||||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||||
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
||||||
@ -1684,11 +1753,20 @@ __global__ void Marlin(
|
|||||||
}
|
}
|
||||||
barrier_release(&locks[locks_off], last);
|
barrier_release(&locks[locks_off], last);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (has_bias && last) {
|
||||||
|
cp_async_wait<0>();
|
||||||
|
__syncthreads();
|
||||||
|
reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];
|
||||||
|
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
if (use_atomic_add && slice_count > 1 && slice_idx != 0)
|
if (use_atomic_add && slice_count > 1 && slice_idx != 0)
|
||||||
wait_negative_and_add(&locks[locks_off]);
|
wait_negative_and_add(&locks[locks_off]);
|
||||||
if (last || use_atomic_add)
|
if (last || use_atomic_add)
|
||||||
// only the last block in a slice actually writes the result
|
// only the last block in a slice actually writes the result
|
||||||
write_result();
|
write_result(last);
|
||||||
slice_row = 0;
|
slice_row = 0;
|
||||||
slice_col_par++;
|
slice_col_par++;
|
||||||
slice_col++;
|
slice_col++;
|
||||||
@ -1706,6 +1784,7 @@ __global__ void Marlin(
|
|||||||
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
|
||||||
// Update slice k/n for scales loading
|
// Update slice k/n for scales loading
|
||||||
if constexpr (has_act_order) {
|
if constexpr (has_act_order) {
|
||||||
slice_k_start = tb_k * slice_row;
|
slice_k_start = tb_k * slice_row;
|
||||||
|
|||||||
@ -142,25 +142,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
|
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
|
||||||
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
|
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
|
||||||
|
|
||||||
// prepare_inputs advance_step
|
|
||||||
ops.def(
|
|
||||||
"advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
|
|
||||||
"Tensor! input_tokens, Tensor sampled_token_ids, "
|
|
||||||
"Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
|
|
||||||
"Tensor block_tables) -> ()");
|
|
||||||
ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn);
|
|
||||||
|
|
||||||
ops.def(
|
|
||||||
"advance_step_flashinfer("
|
|
||||||
" int num_seqs, int num_queries, int block_size,"
|
|
||||||
" Tensor! input_tokens, Tensor sampled_token_ids,"
|
|
||||||
" Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
|
|
||||||
" Tensor block_tables, Tensor! paged_kv_indices,"
|
|
||||||
" Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
|
|
||||||
" Tensor! block_table_bounds"
|
|
||||||
") -> ()");
|
|
||||||
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
|
|
||||||
|
|
||||||
// Layernorm
|
// Layernorm
|
||||||
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
||||||
ops.def(
|
ops.def(
|
||||||
@ -326,6 +307,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
// gptq_marlin Optimized Quantized GEMM for GPTQ.
|
// gptq_marlin Optimized Quantized GEMM for GPTQ.
|
||||||
ops.def(
|
ops.def(
|
||||||
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
|
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
|
||||||
|
"Tensor? b_bias_or_none,"
|
||||||
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
|
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
|
||||||
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
|
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
|
||||||
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
|
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
|
||||||
|
|||||||
@ -497,14 +497,11 @@ ENV HF_HUB_ENABLE_HF_TRANSFER 1
|
|||||||
# Copy in the v1 package for testing (it isn't distributed yet)
|
# Copy in the v1 package for testing (it isn't distributed yet)
|
||||||
COPY vllm/v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
|
COPY vllm/v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
|
||||||
|
|
||||||
# doc requires source code
|
# Source code is used in the `python_only_compile.sh` test
|
||||||
# we hide them inside `test_docs/` , so that this source code
|
# We hide it inside `src/` so that this source code
|
||||||
# will not be imported by other tests
|
# will not be imported by other tests
|
||||||
RUN mkdir test_docs
|
RUN mkdir src
|
||||||
RUN mv docs test_docs/
|
RUN mv vllm src/vllm
|
||||||
RUN cp -r examples test_docs/
|
|
||||||
RUN mv vllm test_docs/
|
|
||||||
RUN mv mkdocs.yaml test_docs/
|
|
||||||
#################### TEST IMAGE ####################
|
#################### TEST IMAGE ####################
|
||||||
|
|
||||||
#################### OPENAI API SERVER ####################
|
#################### OPENAI API SERVER ####################
|
||||||
|
|||||||
@ -2,4 +2,5 @@ Loading Model weights with fastsafetensors
|
|||||||
===================================================================
|
===================================================================
|
||||||
|
|
||||||
Using fastsafetensors library enables loading model weights to GPU memory by leveraging GPU direct storage. See [their GitHub repository](https://github.com/foundation-model-stack/fastsafetensors) for more details.
|
Using fastsafetensors library enables loading model weights to GPU memory by leveraging GPU direct storage. See [their GitHub repository](https://github.com/foundation-model-stack/fastsafetensors) for more details.
|
||||||
For enabling this feature, set the environment variable ``USE_FASTSAFETENSOR`` to ``true``
|
|
||||||
|
To enable this feature, use the ``--load-format fastsafetensors`` command-line argument
|
||||||
|
|||||||
@ -35,6 +35,7 @@ You can check if this is happening by trying the old defaults with `--generation
|
|||||||
If other strategies don't solve the problem, it's likely that the vLLM instance is stuck somewhere. You can use the following environment variables to help debug the issue:
|
If other strategies don't solve the problem, it's likely that the vLLM instance is stuck somewhere. You can use the following environment variables to help debug the issue:
|
||||||
|
|
||||||
- `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging.
|
- `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging.
|
||||||
|
- `export VLLM_LOG_STATS_INTERVAL=1.` to get log statistics more frequently for tracking running queue, waiting queue and cache hit states.
|
||||||
- `export CUDA_LAUNCH_BLOCKING=1` to identify which CUDA kernel is causing the problem.
|
- `export CUDA_LAUNCH_BLOCKING=1` to identify which CUDA kernel is causing the problem.
|
||||||
- `export NCCL_DEBUG=TRACE` to turn on more logging for NCCL.
|
- `export NCCL_DEBUG=TRACE` to turn on more logging for NCCL.
|
||||||
- `export VLLM_TRACE_FUNCTION=1` to record all function calls for inspection in the log files to tell which function crashes or hangs. Do not use this flag unless absolutely needed for debugging, it will cause significant delays in startup time.
|
- `export VLLM_TRACE_FUNCTION=1` to record all function calls for inspection in the log files to tell which function crashes or hangs. Do not use this flag unless absolutely needed for debugging, it will cause significant delays in startup time.
|
||||||
|
|||||||
@ -3,7 +3,8 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import os
|
||||||
import weakref
|
import weakref
|
||||||
from contextlib import ExitStack
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -32,27 +33,130 @@ def temporary_environ(env_vars):
|
|||||||
os.environ[k] = v
|
os.environ[k] = v
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BackendConfig:
|
||||||
|
name: str
|
||||||
|
env_vars: dict
|
||||||
|
comp_config: dict
|
||||||
|
specific_gpu_arch: Optional[tuple] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Define all backend configurations of full cudagraph to be tested
|
||||||
|
backend_configs = {
|
||||||
|
# FA3 on Hopper
|
||||||
|
"FA3":
|
||||||
|
BackendConfig(name="FA3",
|
||||||
|
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
|
||||||
|
comp_config={
|
||||||
|
"cudagraph_mode": "FULL",
|
||||||
|
},
|
||||||
|
specific_gpu_arch=(9, 0)),
|
||||||
|
# FlashMLA on Hopper
|
||||||
|
"FlashMLA":
|
||||||
|
BackendConfig(name="FlashMLA",
|
||||||
|
env_vars={
|
||||||
|
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
|
||||||
|
},
|
||||||
|
comp_config={
|
||||||
|
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||||
|
},
|
||||||
|
specific_gpu_arch=(9, 0)),
|
||||||
|
# Cutlass MLA on Blackwell
|
||||||
|
"CutlassMLA":
|
||||||
|
BackendConfig(
|
||||||
|
name="CutlassMLA",
|
||||||
|
env_vars={
|
||||||
|
"VLLM_USE_V1": "1",
|
||||||
|
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
|
||||||
|
"FORCE_NUM_KV_SPLITS":
|
||||||
|
"1", # TODO: remove this when hang issue is fixed
|
||||||
|
},
|
||||||
|
comp_config={
|
||||||
|
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||||
|
"cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
|
||||||
|
},
|
||||||
|
specific_gpu_arch=(10, 0)),
|
||||||
|
# FA2
|
||||||
|
"FA2":
|
||||||
|
BackendConfig(name="FA2",
|
||||||
|
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
|
||||||
|
comp_config={
|
||||||
|
"cudagraph_mode": "FULL",
|
||||||
|
}),
|
||||||
|
# Triton Attention
|
||||||
|
"TritonAttn":
|
||||||
|
BackendConfig(name="TritonAttn",
|
||||||
|
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
|
||||||
|
comp_config={
|
||||||
|
"cudagraph_mode": "FULL",
|
||||||
|
}),
|
||||||
|
# FlashInfer
|
||||||
|
"FlashInfer":
|
||||||
|
BackendConfig(name="FlashInfer",
|
||||||
|
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
|
||||||
|
comp_config={
|
||||||
|
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
test_params_full_cudagraph = []
|
||||||
|
|
||||||
|
# deepseek-ai/DeepSeek-V2-Lite with MLA
|
||||||
|
MLA_backends = ["FlashMLA", "CutlassMLA"]
|
||||||
|
for mla_backend in MLA_backends:
|
||||||
|
test_params_full_cudagraph.append(
|
||||||
|
pytest.param(
|
||||||
|
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))
|
||||||
|
|
||||||
|
# Qwen/Qwen2-1.5B-Instruct with other backends
|
||||||
|
other_backend_configs = [
|
||||||
|
backend_configs[c] for c in backend_configs if c not in MLA_backends
|
||||||
|
]
|
||||||
|
for backend_config in other_backend_configs:
|
||||||
|
test_params_full_cudagraph.append(
|
||||||
|
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="class")
|
@pytest.fixture(scope="class")
|
||||||
def llm_pair(request):
|
def llm_pair(request):
|
||||||
model = request.param
|
model, backend_config = request.param
|
||||||
|
|
||||||
with temporary_environ({
|
# Dynamically skip test if GPU capability is not met
|
||||||
"VLLM_USE_V1": "1",
|
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
|
||||||
"VLLM_FLASH_ATTN_VERSION": "3"
|
!= current_platform.get_device_capability():
|
||||||
}):
|
if backend_config.specific_gpu_arch == (9, 0):
|
||||||
|
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
|
||||||
|
elif backend_config.specific_gpu_arch == (10, 0):
|
||||||
|
pytest.skip("Only Blackwell GPUs support Cutlass MLA")
|
||||||
|
|
||||||
|
env_vars = {
|
||||||
|
"VLLM_USE_V1": "1",
|
||||||
|
# Force native sampler to avoid potential nondeterminism in FlashInfer
|
||||||
|
# when per-request generators are not used in V1.
|
||||||
|
"VLLM_USE_FLASHINFER_SAMPLER": "0",
|
||||||
|
**backend_config.env_vars,
|
||||||
|
}
|
||||||
|
with temporary_environ(env_vars):
|
||||||
full = LLM(
|
full = LLM(
|
||||||
model=model,
|
model=model,
|
||||||
gpu_memory_utilization=0.45,
|
gpu_memory_utilization=0.43,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
compilation_config=CompilationConfig(full_cuda_graph=True),
|
max_num_seqs=128,
|
||||||
|
compilation_config=\
|
||||||
|
CompilationConfig(**backend_config.comp_config),
|
||||||
|
generation_config="vllm",
|
||||||
|
seed=42,
|
||||||
)
|
)
|
||||||
piecewise = LLM(
|
piecewise = LLM(
|
||||||
model=model,
|
model=model,
|
||||||
gpu_memory_utilization=0.45,
|
gpu_memory_utilization=0.43,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
compilation_config=CompilationConfig(),
|
max_num_seqs=128,
|
||||||
|
compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
|
||||||
|
generation_config="vllm",
|
||||||
|
seed=42,
|
||||||
)
|
)
|
||||||
|
|
||||||
# PyTest caches the fixture values so we use weakref.proxy to enable GC
|
# PyTest caches the fixture values so we use weakref.proxy to enable GC
|
||||||
@ -66,16 +170,7 @@ def llm_pair(request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
|
||||||
"llm_pair",
|
|
||||||
[
|
|
||||||
# Model names for the llm_pair fixture
|
|
||||||
"deepseek-ai/DeepSeek-V2-Lite",
|
|
||||||
"Qwen/Qwen2-1.5B-Instruct"
|
|
||||||
],
|
|
||||||
indirect=True)
|
|
||||||
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
|
|
||||||
reason="Only Hopper GPUs support FA3 and FlashMLA")
|
|
||||||
class TestFullCUDAGraph:
|
class TestFullCUDAGraph:
|
||||||
"""
|
"""
|
||||||
Use a class such that an llm pair is constructed once for all
|
Use a class such that an llm pair is constructed once for all
|
||||||
@ -104,12 +199,14 @@ class TestFullCUDAGraph:
|
|||||||
full cudagraph compilation works for padded cases too.
|
full cudagraph compilation works for padded cases too.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
piecewise_llm, full_cudagraph_llm = llm_pair
|
full_cudagraph_llm, piecewise_llm = llm_pair
|
||||||
|
|
||||||
prompts = ["Hello, my name is"] * batch_size
|
prompts = ["the quick brown fox"] * batch_size
|
||||||
|
# Use purely greedy decoding to avoid top-p truncation sensitivity
|
||||||
|
# that can amplify tiny numeric differences across runtimes.
|
||||||
sampling_params = SamplingParams(temperature=0.0,
|
sampling_params = SamplingParams(temperature=0.0,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
top_p=0.95)
|
top_p=1.0)
|
||||||
|
|
||||||
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
|
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
|
||||||
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
|
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
|
||||||
@ -117,42 +214,16 @@ class TestFullCUDAGraph:
|
|||||||
# Check that all responses are the same
|
# Check that all responses are the same
|
||||||
for piecewise_res, full_res in zip(piecewise_responses,
|
for piecewise_res, full_res in zip(piecewise_responses,
|
||||||
full_responses):
|
full_responses):
|
||||||
assert piecewise_res.outputs[0].text == full_res.outputs[0].text
|
assert piecewise_res.outputs[0].text.lower() == \
|
||||||
|
full_res.outputs[0].text.lower()
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"model, supported",
|
|
||||||
[
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", True),
|
|
||||||
# MLA does not support capturing CUDA Graphs with size > max_num_seqs
|
|
||||||
("deepseek-ai/DeepSeek-V2-Lite", False),
|
|
||||||
])
|
|
||||||
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
|
|
||||||
reason="Only Hopper GPUs support FA3 and FlashMLA")
|
|
||||||
def test_lower_max_num_seqs(model, supported):
|
|
||||||
with temporary_environ({
|
|
||||||
"VLLM_USE_V1": "1",
|
|
||||||
"VLLM_FLASH_ATTN_VERSION": "3"
|
|
||||||
}), ExitStack() as stack:
|
|
||||||
if not supported:
|
|
||||||
stack.enter_context(pytest.raises(RuntimeError))
|
|
||||||
|
|
||||||
llm = LLM(model=model,
|
|
||||||
max_num_seqs=256,
|
|
||||||
trust_remote_code=True,
|
|
||||||
max_model_len=1024,
|
|
||||||
compilation_config=CompilationConfig(
|
|
||||||
full_cuda_graph=True,
|
|
||||||
cudagraph_capture_sizes=[64, 256, 512]))
|
|
||||||
llm.generate(["Hello, my name is"] * 10)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||||
def test_full_cudagraph_with_invalid_backend():
|
def test_full_cudagraph_with_invalid_backend():
|
||||||
with temporary_environ({
|
with temporary_environ({
|
||||||
"VLLM_USE_V1": "1",
|
"VLLM_USE_V1": "1",
|
||||||
"VLLM_FLASH_ATTN_VERSION":
|
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION"
|
||||||
"2" #FA2 not supported with full_cuda_graph
|
# Flex_Attention is not supported with full cuda graph
|
||||||
}), pytest.raises(RuntimeError):
|
}), pytest.raises(RuntimeError):
|
||||||
LLM(model="Qwen/Qwen2-1.5B-Instruct",
|
LLM(model="Qwen/Qwen2-1.5B-Instruct",
|
||||||
compilation_config=CompilationConfig(full_cuda_graph=True))
|
compilation_config=CompilationConfig(cudagraph_mode="FULL"))
|
||||||
|
|||||||
@ -11,10 +11,10 @@ from torch.library import Library
|
|||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
|
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||||
set_current_vllm_config)
|
VllmConfig, set_current_vllm_config)
|
||||||
from vllm.envs import VLLM_USE_V1
|
from vllm.envs import VLLM_USE_V1
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
global_counter = 0
|
global_counter = 0
|
||||||
@ -101,16 +101,33 @@ def test_simple_piecewise_compile(use_inductor):
|
|||||||
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
|
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
|
||||||
num_cudagraph_captured=
|
num_cudagraph_captured=
|
||||||
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||||
), set_forward_context({}, vllm_config=vllm_config):
|
), set_forward_context(None,
|
||||||
|
vllm_config=vllm_config): # background context
|
||||||
|
# warm up with background context
|
||||||
model(inputs)
|
model(inputs)
|
||||||
|
|
||||||
model(torch.randn(2).cuda())
|
# capturing/replaying should under context of cudagraph dispatching
|
||||||
model(torch.randn(1).cuda())
|
with set_forward_context(
|
||||||
|
None,
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||||
|
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
||||||
|
model(torch.randn(2).cuda())
|
||||||
|
with set_forward_context(
|
||||||
|
None,
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||||
|
batch_descriptor=BatchDescriptor(num_tokens=1, )):
|
||||||
|
model(torch.randn(1).cuda())
|
||||||
|
|
||||||
input = torch.zeros(2).cuda()
|
input = torch.zeros(2).cuda()
|
||||||
global global_counter
|
global global_counter
|
||||||
global_counter = 0
|
global_counter = 0
|
||||||
output = model(input)
|
with set_forward_context(
|
||||||
|
None,
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||||
|
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
||||||
|
output = model(input)
|
||||||
assert global_counter == 2
|
assert global_counter == 2
|
||||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||||
|
|||||||
@ -18,9 +18,9 @@ from torch.library import Library
|
|||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
|
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||||
set_current_vllm_config)
|
VllmConfig, set_current_vllm_config)
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
# create a library to hold the custom op
|
# create a library to hold the custom op
|
||||||
@ -276,9 +276,11 @@ def run_model(llama_config,
|
|||||||
)
|
)
|
||||||
if split_attn:
|
if split_attn:
|
||||||
compilation_config.splitting_ops = ["silly.attention"]
|
compilation_config.splitting_ops = ["silly.attention"]
|
||||||
|
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||||
else:
|
else:
|
||||||
compilation_config = CompilationConfig(
|
compilation_config = CompilationConfig(
|
||||||
level=CompilationLevel.NO_COMPILATION, )
|
level=CompilationLevel.NO_COMPILATION, )
|
||||||
|
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||||
|
|
||||||
vllm_config = VllmConfig(compilation_config=compilation_config,
|
vllm_config = VllmConfig(compilation_config=compilation_config,
|
||||||
additional_config=llama_config)
|
additional_config=llama_config)
|
||||||
@ -287,17 +289,37 @@ def run_model(llama_config,
|
|||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
prefix="").eval().cuda()
|
prefix="").eval().cuda()
|
||||||
|
|
||||||
with set_forward_context({}, vllm_config=vllm_config):
|
with set_forward_context({},
|
||||||
|
vllm_config=vllm_config): # background context
|
||||||
B = 16 # max batch size
|
B = 16 # max batch size
|
||||||
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
|
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
|
||||||
positions = torch.arange(B).cuda()
|
positions = torch.arange(B).cuda()
|
||||||
|
|
||||||
|
# warmup for the model with cudagraph_mode NONE
|
||||||
model(input_ids, positions)
|
model(input_ids, positions)
|
||||||
model(input_ids[:2], positions[:2])
|
|
||||||
model(input_ids[:1], positions[:1])
|
# simulate cudagraphs capturing
|
||||||
|
with set_forward_context({},
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=2, )):
|
||||||
|
model(input_ids[:2], positions[:2])
|
||||||
|
with set_forward_context({},
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=1, )):
|
||||||
|
model(input_ids[:1], positions[:1])
|
||||||
|
|
||||||
input_ids[:2].zero_()
|
input_ids[:2].zero_()
|
||||||
output = model(input_ids[:2], positions[:2])
|
# simulate cudagraphs replay
|
||||||
|
with set_forward_context({},
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
|
batch_descriptor=BatchDescriptor(
|
||||||
|
num_tokens=2, )):
|
||||||
|
output = model(input_ids[:2], positions[:2])
|
||||||
|
|
||||||
output = output.cpu()
|
output = output.cpu()
|
||||||
|
|
||||||
|
|||||||
@ -9,10 +9,10 @@ import torch
|
|||||||
import vllm.v1.attention.backends.rocm_aiter_fa # noqa: F401
|
import vllm.v1.attention.backends.rocm_aiter_fa # noqa: F401
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
|
NUM_HEADS = [(4, 4), (8, 2)]
|
||||||
HEAD_SIZES = [128, 256]
|
HEAD_SIZES = [128, 256]
|
||||||
BLOCK_SIZES = [16, 32]
|
BLOCK_SIZES = [16]
|
||||||
DTYPES = [torch.float16, torch.bfloat16]
|
DTYPES = [torch.bfloat16]
|
||||||
QDTYPES = [None]
|
QDTYPES = [None]
|
||||||
# one value large enough to test overflow in index calculation.
|
# one value large enough to test overflow in index calculation.
|
||||||
# one value small enough to test the schema op check
|
# one value small enough to test the schema op check
|
||||||
|
|||||||
@ -29,17 +29,14 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
|
|||||||
NUM_BLOCKS = 4321 # Arbitrary values for testing
|
NUM_BLOCKS = 4321 # Arbitrary values for testing
|
||||||
PARTITION_SIZE = 512
|
PARTITION_SIZE = 512
|
||||||
PARTITION_SIZE_ROCM = 256
|
PARTITION_SIZE_ROCM = 256
|
||||||
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
|
DTYPES = [torch.bfloat16]
|
||||||
DTYPES = [
|
|
||||||
torch.half, torch.bfloat16, torch.float
|
|
||||||
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
|
|
||||||
NUM_GEN_SEQS = [7] # Arbitrary values for testing
|
NUM_GEN_SEQS = [7] # Arbitrary values for testing
|
||||||
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
|
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
|
||||||
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
|
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
|
||||||
|
|
||||||
# This should be sync with get_supported_head_sizes() in
|
# This should be sync with get_supported_head_sizes() in
|
||||||
# vllm.attention.ops.paged_attn.PagedAttention
|
# vllm.attention.ops.paged_attn.PagedAttention
|
||||||
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
|
HEAD_SIZES = [32, 80, 128, 256]
|
||||||
|
|
||||||
BLOCK_SIZES = [16, 32]
|
BLOCK_SIZES = [16, 32]
|
||||||
USE_ALIBI = [False, True]
|
USE_ALIBI = [False, True]
|
||||||
|
|||||||
@ -11,11 +11,11 @@ from vllm import _custom_ops as ops
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
|
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
|
||||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
DTYPES = [torch.bfloat16, torch.float]
|
||||||
NUM_TOKENS = [42] # Arbitrary values for testing
|
NUM_TOKENS = [42] # Arbitrary values for testing
|
||||||
NUM_LAYERS = [1] # Arbitrary values for testing
|
NUM_LAYERS = [1] # Arbitrary values for testing
|
||||||
NUM_HEADS = [8] # Arbitrary values for testing
|
NUM_HEADS = [8] # Arbitrary values for testing
|
||||||
HEAD_SIZES = [64, 80, 120, 256]
|
HEAD_SIZES = [64, 80, 256]
|
||||||
BLOCK_SIZES = [8, 16, 32]
|
BLOCK_SIZES = [8, 16, 32]
|
||||||
CACHE_LAYOUTS = ["NHD", "HND"]
|
CACHE_LAYOUTS = ["NHD", "HND"]
|
||||||
|
|
||||||
|
|||||||
@ -12,14 +12,16 @@ from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
|
|||||||
flash_attn_with_kvcache,
|
flash_attn_with_kvcache,
|
||||||
is_fa_version_supported)
|
is_fa_version_supported)
|
||||||
|
|
||||||
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
|
NUM_HEADS = [(4, 4), (8, 2)]
|
||||||
HEAD_SIZES = [128, 256]
|
HEAD_SIZES = [128, 256]
|
||||||
BLOCK_SIZES = [16, 32]
|
BLOCK_SIZES = [16]
|
||||||
DTYPES = [torch.float16, torch.bfloat16]
|
DTYPES = [torch.bfloat16]
|
||||||
QDTYPES = [None, torch.float8_e4m3fn]
|
QDTYPES = [None, torch.float8_e4m3fn]
|
||||||
# one value large enough to test overflow in index calculation.
|
# one value large enough to test overflow in index calculation.
|
||||||
# one value small enough to test the schema op check
|
# one value small enough to test the schema op check
|
||||||
NUM_BLOCKS = [32768, 2048]
|
NUM_BLOCKS = [32768, 2048]
|
||||||
|
SOFT_CAPS = [None, 50.0]
|
||||||
|
SLIDING_WINDOWS = [None, 256]
|
||||||
|
|
||||||
|
|
||||||
def ref_paged_attn(
|
def ref_paged_attn(
|
||||||
@ -83,9 +85,9 @@ def ref_paged_attn(
|
|||||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
|
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
@pytest.mark.parametrize("sliding_window", [None, 256])
|
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||||
@pytest.mark.parametrize("fa_version", [2, 3])
|
@pytest.mark.parametrize("fa_version", [2, 3])
|
||||||
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@ -198,9 +200,9 @@ def test_flash_attn_with_paged_kv(
|
|||||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
@pytest.mark.parametrize("sliding_window", [None, 256])
|
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
|
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
@pytest.mark.parametrize("fa_version", [2, 3])
|
@pytest.mark.parametrize("fa_version", [2, 3])
|
||||||
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
||||||
|
|||||||
@ -9,11 +9,13 @@ import torch
|
|||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
|
NUM_HEADS = [(32, 8), (6, 1)]
|
||||||
HEAD_SIZES = [128, 256]
|
HEAD_SIZES = [128, 256]
|
||||||
BLOCK_SIZES = [16, 32]
|
BLOCK_SIZES = [16, 32]
|
||||||
DTYPES = [torch.float16, torch.bfloat16]
|
DTYPES = [torch.bfloat16]
|
||||||
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||||
|
SOFT_CAPS = [None, 30.0]
|
||||||
|
SLIDING_WINDOWS = [None, 64]
|
||||||
|
|
||||||
|
|
||||||
def ref_paged_attn(
|
def ref_paged_attn(
|
||||||
@ -76,8 +78,8 @@ def ref_paged_attn(
|
|||||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||||
@pytest.mark.parametrize("sliding_window", [None, 64])
|
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def test_flashinfer_decode_with_paged_kv(
|
def test_flashinfer_decode_with_paged_kv(
|
||||||
kv_lens: list[int],
|
kv_lens: list[int],
|
||||||
@ -173,8 +175,8 @@ def test_flashinfer_decode_with_paged_kv(
|
|||||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||||
@pytest.mark.parametrize("sliding_window", [None, 64])
|
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def test_flashinfer_prefill_with_paged_kv(
|
def test_flashinfer_prefill_with_paged_kv(
|
||||||
seq_lens: list[tuple[int, int]],
|
seq_lens: list[tuple[int, int]],
|
||||||
@ -278,11 +280,11 @@ def test_flashinfer_prefill_with_paged_kv(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]])
|
@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]])
|
||||||
@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)])
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||||
def test_flashinfer_prefill_with_paged_fp8_kv(
|
def test_flashinfer_prefill_with_paged_fp8_kv(
|
||||||
seq_lens: list[tuple[int, int]], num_heads: tuple[int, int],
|
seq_lens: list[tuple[int, int]], num_heads: tuple[int, int],
|
||||||
head_size: int, dtype: torch.dtype, block_size: int,
|
head_size: int, dtype: torch.dtype, block_size: int,
|
||||||
@ -385,11 +387,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
|
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
|
||||||
@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)])
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||||
|
@pytest.mark.skip(reason="TODO: fix the accuracy issue")
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def test_flashinfer_decode_with_paged_fp8_kv(
|
def test_flashinfer_decode_with_paged_fp8_kv(
|
||||||
kv_lens: list[int],
|
kv_lens: list[int],
|
||||||
@ -399,7 +402,6 @@ def test_flashinfer_decode_with_paged_fp8_kv(
|
|||||||
block_size: int,
|
block_size: int,
|
||||||
soft_cap: Optional[float],
|
soft_cap: Optional[float],
|
||||||
) -> None:
|
) -> None:
|
||||||
pytest.skip("TODO: fix the accuracy issue")
|
|
||||||
# test doesn't work for num_heads = (16,16)
|
# test doesn't work for num_heads = (16,16)
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
|
|||||||
@ -20,11 +20,11 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
|||||||
MAX_Q_LEN = 1024
|
MAX_Q_LEN = 1024
|
||||||
MAX_KV_LEN = 4096
|
MAX_KV_LEN = 4096
|
||||||
BATCH_SIZES = [4, 12]
|
BATCH_SIZES = [4, 12]
|
||||||
NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)]
|
NUM_HEADS = [(16, 16), (40, 8)]
|
||||||
HEAD_SIZES = [128]
|
HEAD_SIZES = [128]
|
||||||
BLOCK_SIZES = [16, 32]
|
BLOCK_SIZES = [16]
|
||||||
KV_LAYOUTS = ["HND"]
|
KV_LAYOUTS = ["HND"]
|
||||||
DTYPES = [torch.float16, torch.bfloat16]
|
DTYPES = [torch.bfloat16]
|
||||||
KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
|
KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
|
||||||
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||||
SOFT_CAPS = [None, 50.0]
|
SOFT_CAPS = [None, 50.0]
|
||||||
|
|||||||
@ -19,13 +19,13 @@ from vllm.platforms import current_platform
|
|||||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||||
|
|
||||||
NUM_HEADS = [64]
|
NUM_HEADS = [64]
|
||||||
NUM_QUERIES_PER_KV = [1, 8, 64]
|
NUM_QUERIES_PER_KV = [1, 64]
|
||||||
HEAD_SIZES = [128, 96, 24]
|
HEAD_SIZES = [24, 128]
|
||||||
DTYPES = [torch.float16]
|
DTYPES = [torch.float16]
|
||||||
CUDA_DEVICES = [
|
CUDA_DEVICES = [
|
||||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||||
]
|
]
|
||||||
SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
|
SLIDING_WINDOW = [0, 16, 2048]
|
||||||
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
|
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
|
||||||
|
|
||||||
OPS = [chunked_prefill_paged_decode, context_attention_fwd]
|
OPS = [chunked_prefill_paged_decode, context_attention_fwd]
|
||||||
|
|||||||
@ -9,11 +9,11 @@ import torch
|
|||||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
|
NUM_HEADS = [(4, 4), (8, 2)]
|
||||||
HEAD_SIZES = [128, 256]
|
HEAD_SIZES = [128, 256]
|
||||||
BLOCK_SIZES = [16, 32]
|
BLOCK_SIZES = [16]
|
||||||
|
|
||||||
DTYPES = [torch.float16, torch.bfloat16]
|
DTYPES = [torch.bfloat16]
|
||||||
QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
|
QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
|
||||||
None, torch.float8_e4m3fnuz
|
None, torch.float8_e4m3fnuz
|
||||||
]
|
]
|
||||||
@ -85,7 +85,7 @@ def ref_paged_attn(
|
|||||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
@pytest.mark.parametrize("sliding_window", [None, 256])
|
@pytest.mark.parametrize("sliding_window", [None, 256])
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
|
@pytest.mark.parametrize("soft_cap", [None, 50.0])
|
||||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from einops import rearrange, repeat
|
|||||||
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||||
mamba_chunk_scan_combined)
|
mamba_chunk_scan_combined)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.attention.backends.mamba_attn import (
|
from vllm.v1.attention.backends.mamba2_attn import (
|
||||||
_query_start_loc_to_chunk_indices_offsets)
|
_query_start_loc_to_chunk_indices_offsets)
|
||||||
|
|
||||||
# Added by the IBM Team, 2024
|
# Added by the IBM Team, 2024
|
||||||
|
|||||||
@ -89,14 +89,11 @@ class BatchedMMTensors:
|
|||||||
return BatchedMMTensors(A, B, C, num_expert_tokens)
|
return BatchedMMTensors(A, B, C, num_expert_tokens)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_experts", [8, 16, 32])
|
@pytest.mark.parametrize("num_experts", [8, 32])
|
||||||
@pytest.mark.parametrize("max_tokens_per_expert",
|
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
|
||||||
[32, 64, 128, 192, 224, 256, 512])
|
@pytest.mark.parametrize("K", [128, 1024])
|
||||||
@pytest.mark.parametrize("K", [128, 256, 1024])
|
@pytest.mark.parametrize("N", [128, 1024])
|
||||||
@pytest.mark.parametrize("N", [128, 256, 1024])
|
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"dtype",
|
|
||||||
[torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
|
|
||||||
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
|
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
|
||||||
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
||||||
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||||
|
|||||||
@ -113,8 +113,7 @@ def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
|
|||||||
rtol=0)
|
rtol=0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 11, 127, 128, 3333, 7317])
|
||||||
"num_tokens", [1, 4, 8, 11, 19, 128, 127, 405, 1024, 3333, 6666, 7317])
|
|
||||||
@pytest.mark.parametrize("num_topk", [2, 6, 8])
|
@pytest.mark.parametrize("num_topk", [2, 6, 8])
|
||||||
@pytest.mark.parametrize("num_experts", [64])
|
@pytest.mark.parametrize("num_experts", [64])
|
||||||
@pytest.mark.parametrize("ep_size", [1, 2, 4])
|
@pytest.mark.parametrize("ep_size", [1, 2, 4])
|
||||||
@ -126,7 +125,7 @@ def test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
|
|||||||
ep_size, topk_ids_dtype)
|
ep_size, topk_ids_dtype)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("numel", list(range(1, 8192, 11)))
|
@pytest.mark.parametrize("numel", list(range(1, 8192, 111)))
|
||||||
@pytest.mark.parametrize("num_experts", [32])
|
@pytest.mark.parametrize("num_experts", [32])
|
||||||
@pytest.mark.parametrize("ep_size", [2])
|
@pytest.mark.parametrize("ep_size", [2])
|
||||||
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
|
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
|
||||||
|
|||||||
@ -24,8 +24,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
|||||||
fused_topk, modular_triton_fused_moe)
|
fused_topk, modular_triton_fused_moe)
|
||||||
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
||||||
fused_moe as iterative_moe)
|
fused_moe as iterative_moe)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
|
marlin_permute_bias)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||||
rand_marlin_weight_fp4_like)
|
rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
marlin_quant_fp8_torch)
|
marlin_quant_fp8_torch)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||||
@ -40,6 +42,24 @@ NUM_EXPERTS = [8, 64, 192]
|
|||||||
EP_SIZE = [1, 4]
|
EP_SIZE = [1, 4]
|
||||||
TOP_KS = [2, 6]
|
TOP_KS = [2, 6]
|
||||||
|
|
||||||
|
FUSED_MOE_MNK_FACTORS = [
|
||||||
|
(1, 128, 128),
|
||||||
|
(1, 2048, 128),
|
||||||
|
(33, 2048, 128),
|
||||||
|
(222, 1024, 1024),
|
||||||
|
(32768, 128, 128),
|
||||||
|
(32768, 2048, 511),
|
||||||
|
(40000, 1024, 1024),
|
||||||
|
]
|
||||||
|
|
||||||
|
FUSED_MOE_WN16_MNK_FACTORS = [
|
||||||
|
(1, 128, 128),
|
||||||
|
(1, 1024, 1024),
|
||||||
|
(32, 2048, 128),
|
||||||
|
(32, 1024, 1024),
|
||||||
|
(222, 2048, 1024),
|
||||||
|
]
|
||||||
|
|
||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
vllm_config.scheduler_config.max_num_seqs = 128
|
vllm_config.scheduler_config.max_num_seqs = 128
|
||||||
vllm_config.scheduler_config.max_model_len = 8192
|
vllm_config.scheduler_config.max_model_len = 8192
|
||||||
@ -114,13 +134,11 @@ def run_moe_test(
|
|||||||
return baseline_output
|
return baseline_output
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000])
|
@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS)
|
||||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
|
||||||
@pytest.mark.parametrize("k", [128, 511, 1024])
|
|
||||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||||
@pytest.mark.parametrize("topk", TOP_KS)
|
@pytest.mark.parametrize("topk", TOP_KS)
|
||||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
@pytest.mark.parametrize("padding", [True, False])
|
@pytest.mark.parametrize("padding", [True, False])
|
||||||
@pytest.mark.parametrize("chunk_size", [8192])
|
@pytest.mark.parametrize("chunk_size", [8192])
|
||||||
def test_fused_moe(
|
def test_fused_moe(
|
||||||
@ -233,13 +251,11 @@ def test_fused_moe(
|
|||||||
use_cudagraph=use_cudagraph)
|
use_cudagraph=use_cudagraph)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("m", [1, 32, 222])
|
@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS)
|
||||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
|
||||||
@pytest.mark.parametrize("k", [128, 1024])
|
|
||||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||||
@pytest.mark.parametrize("topk", TOP_KS)
|
@pytest.mark.parametrize("topk", TOP_KS)
|
||||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
@pytest.mark.parametrize("group_size", [64, 128])
|
@pytest.mark.parametrize("group_size", [64, 128])
|
||||||
@pytest.mark.parametrize("has_zp", [True, False])
|
@pytest.mark.parametrize("has_zp", [True, False])
|
||||||
@pytest.mark.parametrize("weight_bits", [4, 8])
|
@pytest.mark.parametrize("weight_bits", [4, 8])
|
||||||
@ -350,8 +366,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
|||||||
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype",
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
[torch.float32, torch.float16, torch.bfloat16])
|
|
||||||
@pytest.mark.parametrize("padding", [True, False])
|
@pytest.mark.parametrize("padding", [True, False])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
||||||
@ -476,8 +491,11 @@ def marlin_moe_generate_valid_test_cases():
|
|||||||
if quant_type == scalar_types.float8_e4m3fn and \
|
if quant_type == scalar_types.float8_e4m3fn and \
|
||||||
group_size not in [-1, 128]:
|
group_size not in [-1, 128]:
|
||||||
return False
|
return False
|
||||||
if quant_type == scalar_types.float4_e2m1f and group_size != 16:
|
if quant_type == scalar_types.float4_e2m1f:
|
||||||
return False
|
if group_size not in [16, 32]:
|
||||||
|
return False
|
||||||
|
if dtype == torch.float16 and group_size == 32:
|
||||||
|
return False
|
||||||
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
|
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -520,31 +538,6 @@ def test_fused_marlin_moe(
|
|||||||
torch.cuda.manual_seed(0)
|
torch.cuda.manual_seed(0)
|
||||||
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||||
|
|
||||||
if quant_type == scalar_types.float8_e4m3fn:
|
|
||||||
if group_size not in [-1, 128]:
|
|
||||||
return
|
|
||||||
if act_order:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Filter act_order
|
|
||||||
if act_order:
|
|
||||||
if quant_type == scalar_types.float8_e4m3fn:
|
|
||||||
return
|
|
||||||
if group_size == -1:
|
|
||||||
return
|
|
||||||
if group_size in (k, n):
|
|
||||||
return
|
|
||||||
if has_zp:
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
if not is_k_full:
|
|
||||||
return
|
|
||||||
|
|
||||||
if quant_type == scalar_types.float4_e2m1f and group_size != 16:
|
|
||||||
return
|
|
||||||
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
|
|
||||||
return
|
|
||||||
|
|
||||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
|
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
|
||||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
|
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
|
||||||
@ -569,13 +562,19 @@ def test_fused_marlin_moe(
|
|||||||
|
|
||||||
for i in range(w1.shape[0]):
|
for i in range(w1.shape[0]):
|
||||||
if quant_type == scalar_types.float4_e2m1f:
|
if quant_type == scalar_types.float4_e2m1f:
|
||||||
w_ref1, qweight1, scales1, global_scale1 = \
|
if group_size == 16:
|
||||||
rand_marlin_weight_fp4_like(w1[i], group_size)
|
w_ref1, qweight1, scales1, global_scale1 = \
|
||||||
|
rand_marlin_weight_nvfp4_like(w1[i], group_size)
|
||||||
|
else:
|
||||||
|
w_ref1, qweight1, scales1 = \
|
||||||
|
rand_marlin_weight_mxfp4_like(w1[i], group_size)
|
||||||
|
global_scale1 = None
|
||||||
|
|
||||||
w_ref1_l.append(w_ref1.T)
|
w_ref1_l.append(w_ref1.T)
|
||||||
qweight1_l.append(qweight1)
|
qweight1_l.append(qweight1)
|
||||||
scales1_l.append(scales1)
|
scales1_l.append(scales1)
|
||||||
global_scale1_l.append(global_scale1)
|
if global_scale1 is not None:
|
||||||
|
global_scale1_l.append(global_scale1)
|
||||||
elif quant_type == scalar_types.float8_e4m3fn:
|
elif quant_type == scalar_types.float8_e4m3fn:
|
||||||
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
|
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
|
||||||
w1[i], group_size)
|
w1[i], group_size)
|
||||||
@ -620,13 +619,19 @@ def test_fused_marlin_moe(
|
|||||||
|
|
||||||
for i in range(w2.shape[0]):
|
for i in range(w2.shape[0]):
|
||||||
if quant_type == scalar_types.float4_e2m1f:
|
if quant_type == scalar_types.float4_e2m1f:
|
||||||
w_ref2, qweight2, scales2, global_scale2 = \
|
if group_size == 16:
|
||||||
rand_marlin_weight_fp4_like(w2[i], group_size)
|
w_ref2, qweight2, scales2, global_scale2 = \
|
||||||
|
rand_marlin_weight_nvfp4_like(w2[i], group_size)
|
||||||
|
else:
|
||||||
|
w_ref2, qweight2, scales2 = \
|
||||||
|
rand_marlin_weight_mxfp4_like(w2[i], group_size)
|
||||||
|
global_scale2 = None
|
||||||
|
|
||||||
w_ref2_l.append(w_ref2.T)
|
w_ref2_l.append(w_ref2.T)
|
||||||
qweight2_l.append(qweight2)
|
qweight2_l.append(qweight2)
|
||||||
scales2_l.append(scales2)
|
scales2_l.append(scales2)
|
||||||
global_scale2_l.append(global_scale2)
|
if global_scale2 is not None:
|
||||||
|
global_scale2_l.append(global_scale2)
|
||||||
elif quant_type == scalar_types.float8_e4m3fn:
|
elif quant_type == scalar_types.float8_e4m3fn:
|
||||||
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
|
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
|
||||||
w2[i], group_size)
|
w2[i], group_size)
|
||||||
@ -677,6 +682,8 @@ def test_fused_marlin_moe(
|
|||||||
a,
|
a,
|
||||||
qweight1,
|
qweight1,
|
||||||
qweight2,
|
qweight2,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
scales1,
|
scales1,
|
||||||
scales2,
|
scales2,
|
||||||
score,
|
score,
|
||||||
@ -698,6 +705,119 @@ def test_fused_marlin_moe(
|
|||||||
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
|
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.flaky(reruns=2)
|
||||||
|
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||||
|
@pytest.mark.parametrize("m", [1, 256])
|
||||||
|
def test_fused_marlin_moe_with_bias(m):
|
||||||
|
torch.cuda.manual_seed(0)
|
||||||
|
|
||||||
|
e, topk = 32, 4
|
||||||
|
n, k = 2048, 2048
|
||||||
|
group_size = 128
|
||||||
|
act_order = False
|
||||||
|
is_k_full = True
|
||||||
|
quant_type = scalar_types.uint4b8
|
||||||
|
dtype = torch.half
|
||||||
|
|
||||||
|
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||||
|
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||||
|
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||||
|
b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
|
||||||
|
b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10
|
||||||
|
|
||||||
|
b_bias1_l = []
|
||||||
|
w_ref1_l = []
|
||||||
|
qweight1_l = []
|
||||||
|
scales1_l = []
|
||||||
|
g_idx1_l = []
|
||||||
|
sort_indices1_l = []
|
||||||
|
|
||||||
|
for i in range(w1.shape[0]):
|
||||||
|
test_perm = torch.randperm(k)
|
||||||
|
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
|
||||||
|
marlin_quantize(w1[i].transpose(1, 0), quant_type,
|
||||||
|
group_size, act_order, test_perm)
|
||||||
|
|
||||||
|
w_ref1_l.append(w_ref1.T)
|
||||||
|
qweight1_l.append(qweight1)
|
||||||
|
scales1_l.append(scales1)
|
||||||
|
g_idx1_l.append(g_idx1)
|
||||||
|
sort_indices1_l.append(sort_indices1)
|
||||||
|
b_bias1_l.append(marlin_permute_bias(b_bias1[i]))
|
||||||
|
|
||||||
|
w_ref1 = stack_and_dev(w_ref1_l)
|
||||||
|
qweight1 = stack_and_dev(qweight1_l).contiguous()
|
||||||
|
scales1 = stack_and_dev(scales1_l)
|
||||||
|
global_scale1 = None
|
||||||
|
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
|
||||||
|
zeros1 = None
|
||||||
|
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
|
||||||
|
marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None
|
||||||
|
|
||||||
|
b_bias2_l = []
|
||||||
|
w_ref2_l = []
|
||||||
|
qweight2_l = []
|
||||||
|
scales2_l = []
|
||||||
|
g_idx2_l = []
|
||||||
|
sort_indices2_l = []
|
||||||
|
|
||||||
|
for i in range(w2.shape[0]):
|
||||||
|
test_perm = torch.randperm(n)
|
||||||
|
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
|
||||||
|
marlin_quantize(w2[i].transpose(1, 0), quant_type,
|
||||||
|
group_size, act_order, test_perm)
|
||||||
|
|
||||||
|
w_ref2_l.append(w_ref2.T)
|
||||||
|
qweight2_l.append(qweight2)
|
||||||
|
scales2_l.append(scales2)
|
||||||
|
g_idx2_l.append(g_idx2)
|
||||||
|
sort_indices2_l.append(sort_indices2)
|
||||||
|
b_bias2_l.append(marlin_permute_bias(b_bias2[i]))
|
||||||
|
|
||||||
|
w_ref2 = stack_and_dev(w_ref2_l)
|
||||||
|
qweight2 = stack_and_dev(qweight2_l).contiguous()
|
||||||
|
scales2 = stack_and_dev(scales2_l)
|
||||||
|
global_scale2 = None
|
||||||
|
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
|
||||||
|
zeros2 = None
|
||||||
|
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
|
||||||
|
marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None
|
||||||
|
|
||||||
|
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||||
|
|
||||||
|
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||||
|
|
||||||
|
with set_current_vllm_config(vllm_config):
|
||||||
|
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1,
|
||||||
|
b_bias2)
|
||||||
|
|
||||||
|
marlin_output = torch.ops.vllm.fused_marlin_moe(
|
||||||
|
a,
|
||||||
|
qweight1,
|
||||||
|
qweight2,
|
||||||
|
marlin_bias1,
|
||||||
|
marlin_bias2,
|
||||||
|
scales1,
|
||||||
|
scales2,
|
||||||
|
score,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
global_num_experts=e,
|
||||||
|
expert_map=None,
|
||||||
|
global_scale1=global_scale1,
|
||||||
|
global_scale2=global_scale2,
|
||||||
|
g_idx1=g_idx1,
|
||||||
|
g_idx2=g_idx2,
|
||||||
|
sort_indices1=sort_indices1,
|
||||||
|
sort_indices2=sort_indices2,
|
||||||
|
w1_zeros=zeros1,
|
||||||
|
w2_zeros=zeros2,
|
||||||
|
quant_type_id=quant_type.id,
|
||||||
|
is_k_full=is_k_full)
|
||||||
|
|
||||||
|
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
|
||||||
|
|
||||||
|
|
||||||
def test_moe_align_block_size_opcheck():
|
def test_moe_align_block_size_opcheck():
|
||||||
num_experts = 4
|
num_experts = 4
|
||||||
block_size = 4
|
block_size = 4
|
||||||
|
|||||||
@ -15,10 +15,10 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import round_up
|
from vllm.utils import round_up
|
||||||
|
|
||||||
NUM_TOKENS = [1, 3, 7, 16, 256, 2256, 4096]
|
NUM_TOKENS = [1, 3, 256, 2256, 4096]
|
||||||
NUM_EXPERTS = [32, 160, 256, 257, 512]
|
NUM_EXPERTS = [32, 160, 256, 257]
|
||||||
TOP_KS = [1, 2, 16, 32]
|
TOP_KS = [1, 2, 16, 32]
|
||||||
BLOCK_SIZES = [32, 64, 128, 256]
|
BLOCK_SIZES = [32, 128]
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
NUM_EXPERTS = [16, 64, 256]
|
NUM_EXPERTS = [16, 64, 256]
|
||||||
TOP_KS = [2, 4, 6, 8]
|
TOP_KS = [2, 6, 8]
|
||||||
EP_SIZE = [1, 4, 16]
|
EP_SIZE = [1, 4, 16]
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
|
|
||||||
@ -177,11 +177,11 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor,
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n_token", [1, 33, 64, 222, 1024, 2048, 3000, 5000])
|
@pytest.mark.parametrize("n_token", [1, 33, 1024, 5000])
|
||||||
@pytest.mark.parametrize("n_hidden", [2048, 4096, 7168])
|
@pytest.mark.parametrize("n_hidden", [2048, 7168])
|
||||||
@pytest.mark.parametrize("n_expert", NUM_EXPERTS)
|
@pytest.mark.parametrize("n_expert", NUM_EXPERTS)
|
||||||
@pytest.mark.parametrize("topk", TOP_KS)
|
@pytest.mark.parametrize("topk", TOP_KS)
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||||
@pytest.mark.parametrize("align_block_size", [None, 128])
|
@pytest.mark.parametrize("align_block_size", [None, 128])
|
||||||
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
|
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
|
||||||
|
|||||||
@ -44,6 +44,14 @@ requires_pplx = pytest.mark.skipif(
|
|||||||
reason="Requires PPLX kernels",
|
reason="Requires PPLX kernels",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
BATCHED_MOE_MNK_FACTORS = [
|
||||||
|
(1, 128, 128),
|
||||||
|
(33, 2048, 128),
|
||||||
|
(64, 128, 2048),
|
||||||
|
(222, 128, 128),
|
||||||
|
(222, 2048, 1024),
|
||||||
|
]
|
||||||
|
|
||||||
PPLX_COMBOS = [
|
PPLX_COMBOS = [
|
||||||
# TODO: figure out why this fails, seems to be test problem
|
# TODO: figure out why this fails, seems to be test problem
|
||||||
#(1, 128, 128),
|
#(1, 128, 128),
|
||||||
@ -152,9 +160,7 @@ def torch_batched_moe(
|
|||||||
return torch_finalize(out, topk_weight, topk_ids)
|
return torch_finalize(out, topk_weight, topk_ids)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
@pytest.mark.parametrize("m,n,k", BATCHED_MOE_MNK_FACTORS)
|
||||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
|
||||||
@pytest.mark.parametrize("k", [128, 512, 1024])
|
|
||||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||||
@pytest.mark.parametrize("topk", TOP_KS)
|
@pytest.mark.parametrize("topk", TOP_KS)
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
|
|||||||
139
tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
Normal file
139
tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX,
|
||||||
|
convert_swizzled_to_linear, dequantize_nvfp4_to_dtype)
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm
|
||||||
|
|
||||||
|
if not current_platform.has_device_capability(100):
|
||||||
|
pytest.skip(
|
||||||
|
reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||||
|
allow_module_level=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
DTYPES = [torch.float16, torch.bfloat16]
|
||||||
|
# m, n, k
|
||||||
|
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
|
||||||
|
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
|
||||||
|
SHAPES.extend(PAD_SHAPES)
|
||||||
|
|
||||||
|
SEEDS = [42]
|
||||||
|
CUDA_DEVICES = ["cuda:0"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_ref_results(
|
||||||
|
a_fp4,
|
||||||
|
b_fp4,
|
||||||
|
a_sf,
|
||||||
|
b_sf,
|
||||||
|
a_global_scale,
|
||||||
|
b_global_scale,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
dtype,
|
||||||
|
block_size,
|
||||||
|
device,
|
||||||
|
):
|
||||||
|
_, m_k = a_fp4.shape
|
||||||
|
_, n_k = b_fp4.shape
|
||||||
|
assert m_k == n_k
|
||||||
|
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
|
||||||
|
a_sf,
|
||||||
|
a_global_scale,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
block_size=block_size)
|
||||||
|
b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4,
|
||||||
|
b_sf,
|
||||||
|
b_global_scale,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
block_size=block_size)
|
||||||
|
return torch.matmul(a_in_dtype, b_in_dtype.t())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("shape", SHAPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
|
@pytest.mark.parametrize("backend", ["cutlass", "trtllm"])
|
||||||
|
@pytest.mark.parametrize("autotune", [False, True])
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_flashinfer_nvfp4_gemm(
|
||||||
|
dtype: torch.dtype,
|
||||||
|
shape: tuple[int, int, int],
|
||||||
|
seed: int,
|
||||||
|
device: str,
|
||||||
|
backend: str,
|
||||||
|
autotune: bool,
|
||||||
|
) -> None:
|
||||||
|
if backend == "trtllm" and dtype == torch.float16:
|
||||||
|
pytest.skip(
|
||||||
|
"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")
|
||||||
|
|
||||||
|
current_platform.seed_everything(seed)
|
||||||
|
m, n, packed_k = shape
|
||||||
|
k = packed_k * 2
|
||||||
|
block_size = 16
|
||||||
|
a_dtype = torch.randn((m, k), dtype=dtype, device=device)
|
||||||
|
b_dtype = torch.randn((n, k), dtype=dtype, device=device)
|
||||||
|
|
||||||
|
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||||
|
torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
|
||||||
|
b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||||
|
torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
|
||||||
|
alpha = 1.0 / (a_global_scale * b_global_scale)
|
||||||
|
# ops.scaled_fp4_quant returns swizzled scales, while weights
|
||||||
|
# from checkpoints are in linear scales.
|
||||||
|
# So instead of needing to swizzle for cutlass as in modelopt.py,
|
||||||
|
# we need to unswizzle for trtllm here.
|
||||||
|
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
|
||||||
|
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
|
||||||
|
|
||||||
|
# get_ref_results unswizzles the scales internally.
|
||||||
|
expected_out = get_ref_results(
|
||||||
|
a_fp4,
|
||||||
|
b_fp4,
|
||||||
|
a_scale_interleaved,
|
||||||
|
b_scale_interleaved,
|
||||||
|
a_global_scale,
|
||||||
|
b_global_scale,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
dtype,
|
||||||
|
block_size,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
|
||||||
|
import flashinfer
|
||||||
|
|
||||||
|
if backend == "trtllm":
|
||||||
|
epilogue_tile_m = 128
|
||||||
|
b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8),
|
||||||
|
epilogue_tile_m)
|
||||||
|
|
||||||
|
b_scale_interleaved = convert_swizzled_to_linear(
|
||||||
|
b_scale_interleaved, n, k, block_size)
|
||||||
|
b_scale_interleaved = (flashinfer.shuffle_matrix_sf_a(
|
||||||
|
b_scale_interleaved.view(torch.uint8), epilogue_tile_m).reshape(
|
||||||
|
b_scale_interleaved.shape).view(torch.float8_e4m3fn))
|
||||||
|
|
||||||
|
with flashinfer.autotune(autotune):
|
||||||
|
out = flashinfer_scaled_fp4_mm(
|
||||||
|
a_fp4,
|
||||||
|
b_fp4,
|
||||||
|
a_scale_interleaved,
|
||||||
|
b_scale_interleaved,
|
||||||
|
alpha,
|
||||||
|
dtype,
|
||||||
|
backend=backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.testing.assert_close(out,
|
||||||
|
expected_out.to(dtype=dtype),
|
||||||
|
atol=1e-1,
|
||||||
|
rtol=1e-1)
|
||||||
@ -11,11 +11,9 @@ from tests.kernels.quant_utils import (FP8_DTYPE,
|
|||||||
from tests.kernels.utils import opcheck
|
from tests.kernels.utils import opcheck
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
DTYPES = [torch.bfloat16, torch.float]
|
||||||
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
|
HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
|
||||||
8193] # Arbitrary values for testing
|
NUM_TOKENS = [1, 7, 4096]
|
||||||
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
|
|
||||||
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
|
|
||||||
SCALE_UBS = [True, False]
|
SCALE_UBS = [True, False]
|
||||||
SEEDS = [0]
|
SEEDS = [0]
|
||||||
|
|
||||||
|
|||||||
@ -9,10 +9,9 @@ from tests.kernels.utils import opcheck
|
|||||||
from vllm._custom_ops import scaled_int8_quant
|
from vllm._custom_ops import scaled_int8_quant
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
DTYPES = [torch.bfloat16, torch.float]
|
||||||
HIDDEN_SIZES = [16, 67, 768, 5137, 8193] # Arbitrary values for testing
|
HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
|
||||||
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
|
NUM_TOKENS = [1, 7, 4096]
|
||||||
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
|
|
||||||
SEEDS = [0]
|
SEEDS = [0]
|
||||||
SCALE = [0.1, 2.1]
|
SCALE = [0.1, 2.1]
|
||||||
|
|
||||||
|
|||||||
@ -34,8 +34,6 @@ IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
|
|||||||
|
|
||||||
MNK_SHAPES = [
|
MNK_SHAPES = [
|
||||||
(1, 128, 128),
|
(1, 128, 128),
|
||||||
(1, 512, 1024),
|
|
||||||
(1, 4096, 4096),
|
|
||||||
(1, 8192, 28672),
|
(1, 8192, 28672),
|
||||||
(13, 8192, 4096),
|
(13, 8192, 4096),
|
||||||
(26, 4096, 8192),
|
(26, 4096, 8192),
|
||||||
@ -43,8 +41,6 @@ MNK_SHAPES = [
|
|||||||
(64, 8192, 28672),
|
(64, 8192, 28672),
|
||||||
(257, 128, 4096),
|
(257, 128, 4096),
|
||||||
(257, 4224, 4160),
|
(257, 4224, 4160),
|
||||||
(257, 4096, 4096),
|
|
||||||
(1024, 4096, 8192),
|
|
||||||
(1024, 8192, 4096),
|
(1024, 8192, 4096),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -19,10 +19,11 @@ from vllm.model_executor.layers.quantization.qqq import (
|
|||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||||
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
|
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
|
||||||
marlin_make_workspace_new, marlin_permute_scales,
|
marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
|
||||||
query_marlin_supported_quant_types)
|
query_marlin_supported_quant_types)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||||
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like)
|
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like,
|
||||||
|
rand_marlin_weight_nvfp4_like)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
marlin_quant_fp8_torch)
|
marlin_quant_fp8_torch)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||||
@ -39,7 +40,7 @@ from vllm.scalar_type import scalar_types
|
|||||||
ACT_ORDER_OPTS = [False, True]
|
ACT_ORDER_OPTS = [False, True]
|
||||||
K_FULL_OPTS = [False, True]
|
K_FULL_OPTS = [False, True]
|
||||||
USE_ATOMIC_ADD_OPTS = [False, True]
|
USE_ATOMIC_ADD_OPTS = [False, True]
|
||||||
USE_FP32_REDUCE_OPTS = [False, True]
|
USE_FP32_REDUCE_OPTS = [True]
|
||||||
|
|
||||||
MARLIN_K_CHUNKS = [128]
|
MARLIN_K_CHUNKS = [128]
|
||||||
MARLIN_N_CHUNKS = [64, 256]
|
MARLIN_N_CHUNKS = [64, 256]
|
||||||
@ -52,12 +53,8 @@ HQQ_SUPPORTED_GROUP_SIZES = [64]
|
|||||||
MNK_FACTORS = [
|
MNK_FACTORS = [
|
||||||
(1, 1, 1),
|
(1, 1, 1),
|
||||||
(1, 4, 8),
|
(1, 4, 8),
|
||||||
(1, 7, 5),
|
|
||||||
(13, 17, 67),
|
|
||||||
(26, 37, 13),
|
(26, 37, 13),
|
||||||
(67, 13, 11),
|
|
||||||
(257, 13, 11),
|
(257, 13, 11),
|
||||||
(658, 13, 11),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
DTYPES = [torch.float16, torch.bfloat16]
|
DTYPES = [torch.float16, torch.bfloat16]
|
||||||
@ -202,17 +199,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
|||||||
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
|
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
|
||||||
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
|
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
|
||||||
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
|
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
|
||||||
def test_gptq_marlin_gemm(
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
k_chunk,
|
def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||||
n_chunk,
|
mnk_factors, act_order, is_k_full, use_atomic_add,
|
||||||
quant_type,
|
use_fp32_reduce, dtype):
|
||||||
group_size,
|
|
||||||
mnk_factors,
|
|
||||||
act_order,
|
|
||||||
is_k_full,
|
|
||||||
use_atomic_add,
|
|
||||||
use_fp32_reduce,
|
|
||||||
):
|
|
||||||
m_factor, n_factor, k_factor = mnk_factors
|
m_factor, n_factor, k_factor = mnk_factors
|
||||||
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||||
|
|
||||||
@ -231,14 +221,23 @@ def test_gptq_marlin_gemm(
|
|||||||
if size_k % group_size != 0:
|
if size_k % group_size != 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
a_input = rand_data((size_m, size_k))
|
a_input = rand_data((size_m, size_k), dtype)
|
||||||
b_weight = rand_data((size_k, size_n))
|
b_weight = rand_data((size_k, size_n), dtype)
|
||||||
|
|
||||||
if quant_type == scalar_types.float4_e2m1f:
|
if quant_type == scalar_types.float4_e2m1f:
|
||||||
if group_size != 16 or act_order:
|
if group_size not in [16, 32] or act_order:
|
||||||
return
|
return
|
||||||
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
|
if group_size == 32 and dtype == torch.float16:
|
||||||
b_weight.T, group_size)
|
return
|
||||||
|
|
||||||
|
if group_size == 16:
|
||||||
|
w_ref, marlin_q_w, marlin_s, marlin_s2 = \
|
||||||
|
rand_marlin_weight_nvfp4_like(b_weight.T, group_size)
|
||||||
|
else:
|
||||||
|
w_ref, marlin_q_w, marlin_s = \
|
||||||
|
rand_marlin_weight_mxfp4_like(b_weight.T, group_size)
|
||||||
|
marlin_s2 = None
|
||||||
|
|
||||||
g_idx = None
|
g_idx = None
|
||||||
sort_indices = None
|
sort_indices = None
|
||||||
marlin_zp = None
|
marlin_zp = None
|
||||||
@ -272,8 +271,8 @@ def test_gptq_marlin_gemm(
|
|||||||
workspace = marlin_make_workspace_new(w_ref.device)
|
workspace = marlin_make_workspace_new(w_ref.device)
|
||||||
|
|
||||||
opcheck(torch.ops._C.gptq_marlin_gemm,
|
opcheck(torch.ops._C.gptq_marlin_gemm,
|
||||||
(a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx,
|
(a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp,
|
||||||
sort_indices, workspace, quant_type.id, a_input.shape[0],
|
g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0],
|
||||||
b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
|
b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
|
||||||
use_fp32_reduce, False),
|
use_fp32_reduce, False),
|
||||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||||
@ -282,6 +281,7 @@ def test_gptq_marlin_gemm(
|
|||||||
a_input,
|
a_input,
|
||||||
None,
|
None,
|
||||||
marlin_q_w,
|
marlin_q_w,
|
||||||
|
None,
|
||||||
marlin_s,
|
marlin_s,
|
||||||
marlin_s2,
|
marlin_s2,
|
||||||
marlin_zp,
|
marlin_zp,
|
||||||
@ -418,6 +418,7 @@ def test_hqq_marlin_gemm(
|
|||||||
a_input,
|
a_input,
|
||||||
None,
|
None,
|
||||||
marlin_w_q,
|
marlin_w_q,
|
||||||
|
None,
|
||||||
marlin_s,
|
marlin_s,
|
||||||
None,
|
None,
|
||||||
marlin_zp,
|
marlin_zp,
|
||||||
@ -531,6 +532,7 @@ def test_marlin_gemm_subset_input():
|
|||||||
a_input,
|
a_input,
|
||||||
None,
|
None,
|
||||||
marlin_q_w,
|
marlin_q_w,
|
||||||
|
None,
|
||||||
marlin_s,
|
marlin_s,
|
||||||
None,
|
None,
|
||||||
marlin_zp,
|
marlin_zp,
|
||||||
@ -555,6 +557,53 @@ def test_marlin_gemm_subset_input():
|
|||||||
assert max_diff < 0.04
|
assert max_diff < 0.04
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("size_m", [1, 256])
|
||||||
|
def test_marlin_gemm_with_bias(size_m):
|
||||||
|
quant_type = scalar_types.uint4b8
|
||||||
|
group_size = 128
|
||||||
|
|
||||||
|
size_k, size_n = 1024, 2048
|
||||||
|
a_input = rand_data((size_m, size_k))
|
||||||
|
b_weight = rand_data((size_k, size_n))
|
||||||
|
b_bias = rand_data((size_n, )) * 10
|
||||||
|
|
||||||
|
marlin_bias = marlin_permute_bias(b_bias)
|
||||||
|
|
||||||
|
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||||
|
b_weight, quant_type, group_size, False)
|
||||||
|
|
||||||
|
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||||
|
workspace = marlin_make_workspace_new(a_input.device)
|
||||||
|
|
||||||
|
output = ops.gptq_marlin_gemm(
|
||||||
|
a_input,
|
||||||
|
None,
|
||||||
|
marlin_q_w,
|
||||||
|
marlin_bias,
|
||||||
|
marlin_s,
|
||||||
|
None,
|
||||||
|
marlin_zp,
|
||||||
|
g_idx,
|
||||||
|
sort_indices,
|
||||||
|
workspace,
|
||||||
|
quant_type,
|
||||||
|
a_input.shape[0],
|
||||||
|
b_weight.shape[1],
|
||||||
|
a_input.shape[1],
|
||||||
|
is_k_full=True,
|
||||||
|
use_atomic_add=False,
|
||||||
|
use_fp32_reduce=True,
|
||||||
|
is_zp_float=False,
|
||||||
|
)
|
||||||
|
output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
max_diff = compute_max_diff(output, output_ref)
|
||||||
|
|
||||||
|
assert max_diff < 0.04
|
||||||
|
|
||||||
|
|
||||||
def test_marlin_gemm_opcheck():
|
def test_marlin_gemm_opcheck():
|
||||||
size_m = 2048
|
size_m = 2048
|
||||||
size_n = 4096
|
size_n = 4096
|
||||||
|
|||||||
@ -65,9 +65,12 @@ def test_nvfp4_gemm(
|
|||||||
b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||||
torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
|
torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
|
||||||
alpha = 1. / (a_global_scale * b_global_scale)
|
alpha = 1. / (a_global_scale * b_global_scale)
|
||||||
|
# ops.scaled_fp4_quant returns swizzled scales, while weights
|
||||||
|
# from checkpoints are in linear scales.
|
||||||
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
|
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
|
||||||
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
|
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
|
||||||
|
|
||||||
|
# get_ref_results unswizzles the scales internally.
|
||||||
expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved,
|
expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved,
|
||||||
b_scale_interleaved, a_global_scale,
|
b_scale_interleaved, a_global_scale,
|
||||||
b_global_scale, m, n, dtype, block_size,
|
b_global_scale, m, n, dtype, block_size,
|
||||||
|
|||||||
@ -8,15 +8,55 @@ from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
DTYPES = [torch.bfloat16, torch.float16]
|
DTYPES = [torch.bfloat16, torch.float16]
|
||||||
M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192]
|
# Specific (N, K, M) combinations for targeted testing
|
||||||
K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192] # k % 8 == 0
|
NKM_FACTORS_LLMM1 = [
|
||||||
N = [1, 2, 3, 4]
|
# Small, medium, large cases
|
||||||
|
(1, 8, 16),
|
||||||
|
(1, 32, 64),
|
||||||
|
(1, 128, 256),
|
||||||
|
(1, 512, 1024),
|
||||||
|
(1, 2048, 4096),
|
||||||
|
# Edge cases with specific K sizes
|
||||||
|
(1, 6144, 1024),
|
||||||
|
(1, 8192, 2048),
|
||||||
|
# Very large case
|
||||||
|
(1, 4096, 8192),
|
||||||
|
]
|
||||||
|
|
||||||
|
NKM_FACTORS_WVSPLITK = [
|
||||||
|
# Different batch sizes with key dimensions
|
||||||
|
(1, 16, 16),
|
||||||
|
(1, 64, 64),
|
||||||
|
(2, 256, 256),
|
||||||
|
(3, 1024, 1024),
|
||||||
|
(4, 4096, 4096),
|
||||||
|
# Extended K values
|
||||||
|
(1, 9216, 512),
|
||||||
|
(2, 10240, 1024),
|
||||||
|
(4, 16384, 8192),
|
||||||
|
# Minimum M constraint validation (m >= 8)
|
||||||
|
(1, 64, 8),
|
||||||
|
(2, 128, 8),
|
||||||
|
(4, 256, 8),
|
||||||
|
]
|
||||||
|
|
||||||
|
NKM_FACTORS_WVSPLITK_FP8 = [
|
||||||
|
# FP8-specific cases with K % 16 == 0
|
||||||
|
(1, 16, 16),
|
||||||
|
(1, 64, 64),
|
||||||
|
(2, 512, 512),
|
||||||
|
(3, 2048, 2048),
|
||||||
|
(4, 4096, 4096),
|
||||||
|
# Extended FP8 dimensions not covered by WVSPLITK
|
||||||
|
(1, 14336, 1024),
|
||||||
|
(2, 24576, 2048),
|
||||||
|
(4, 32768, 28672),
|
||||||
|
]
|
||||||
|
|
||||||
SEEDS = [0]
|
SEEDS = [0]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n", [1]) # only test for batch size 1
|
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
|
||||||
@pytest.mark.parametrize("k", K)
|
|
||||||
@pytest.mark.parametrize("m", M)
|
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
|
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
|
||||||
@pytest.mark.parametrize("seed", SEEDS)
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@ -34,9 +74,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
|
|||||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n", N) # only test for batch size <= 4
|
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
|
||||||
@pytest.mark.parametrize("k", K + [9216, 10240, 16384])
|
|
||||||
@pytest.mark.parametrize("m", [8] + M) # m >= 8
|
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("seed", SEEDS)
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||||
@ -54,9 +92,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
|
|||||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n", N) # only test for batch size <= 4
|
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
|
||||||
@pytest.mark.parametrize("k", K[1:] + [14336, 24576, 32768]) # k % 16 == 0
|
|
||||||
@pytest.mark.parametrize("m", M + [28672]) # m >= 16
|
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("seed", SEEDS)
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
|
|||||||
@ -60,10 +60,18 @@ def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
|
|||||||
num_logprobs)
|
num_logprobs)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("M", [1, 33, 64, 512])
|
MNK_FACTORS = [
|
||||||
@pytest.mark.parametrize("N", [256, 971, 20486])
|
(1, 256, 128),
|
||||||
@pytest.mark.parametrize("K", [128, 496, 1024])
|
(33, 256, 496),
|
||||||
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
|
(64, 971, 1024),
|
||||||
|
(64, 20486, 128),
|
||||||
|
(512, 256, 496),
|
||||||
|
(512, 20486, 1024),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("M,N,K", MNK_FACTORS)
|
||||||
|
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
|
||||||
@pytest.mark.parametrize("in_dtype", get_8bit_types())
|
@pytest.mark.parametrize("in_dtype", get_8bit_types())
|
||||||
@pytest.mark.parametrize("use_scalar_scale_a", [True, False])
|
@pytest.mark.parametrize("use_scalar_scale_a", [True, False])
|
||||||
@pytest.mark.parametrize("use_scalar_scale_b", [True, False])
|
@pytest.mark.parametrize("use_scalar_scale_b", [True, False])
|
||||||
|
|||||||
@ -1064,6 +1064,8 @@ def torch_experts(
|
|||||||
topk_weight: torch.Tensor,
|
topk_weight: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
|
b_bias1: Optional[torch.Tensor] = None,
|
||||||
|
b_bias2: Optional[torch.Tensor] = None,
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
@ -1108,8 +1110,13 @@ def torch_experts(
|
|||||||
if mask.sum():
|
if mask.sum():
|
||||||
if quant_dtype is None:
|
if quant_dtype is None:
|
||||||
tmp1 = a[mask] @ w1[i].transpose(0, 1)
|
tmp1 = a[mask] @ w1[i].transpose(0, 1)
|
||||||
|
if b_bias1 is not None:
|
||||||
|
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
|
||||||
tmp2 = SiluAndMul()(tmp1)
|
tmp2 = SiluAndMul()(tmp1)
|
||||||
out[mask] = tmp2 @ w2[i].transpose(0, 1)
|
out[mask] = tmp2 @ w2[i].transpose(0, 1)
|
||||||
|
if b_bias2 is not None:
|
||||||
|
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(
|
||||||
|
tmp1.dtype)
|
||||||
elif block_shape is not None:
|
elif block_shape is not None:
|
||||||
# block quantized
|
# block quantized
|
||||||
assert (a_scale is not None and w1_scale is not None
|
assert (a_scale is not None and w1_scale is not None
|
||||||
@ -1117,6 +1124,8 @@ def torch_experts(
|
|||||||
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
|
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
|
||||||
w1_scale[i], block_shape,
|
w1_scale[i], block_shape,
|
||||||
out.dtype)
|
out.dtype)
|
||||||
|
if b_bias1 is not None:
|
||||||
|
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
|
||||||
tmp2 = SiluAndMul()(tmp1)
|
tmp2 = SiluAndMul()(tmp1)
|
||||||
tmp2, b_scale = moe_kernel_quantize_input(
|
tmp2, b_scale = moe_kernel_quantize_input(
|
||||||
tmp2, a2_scale, quant_dtype, per_act_token_quant,
|
tmp2, a2_scale, quant_dtype, per_act_token_quant,
|
||||||
@ -1125,6 +1134,9 @@ def torch_experts(
|
|||||||
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
|
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
|
||||||
w2_scale[i], block_shape,
|
w2_scale[i], block_shape,
|
||||||
out.dtype)
|
out.dtype)
|
||||||
|
if b_bias2 is not None:
|
||||||
|
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(
|
||||||
|
tmp1.dtype)
|
||||||
else:
|
else:
|
||||||
assert (a_scale is not None and w1_scale is not None
|
assert (a_scale is not None and w1_scale is not None
|
||||||
and w2_scale is not None)
|
and w2_scale is not None)
|
||||||
@ -1133,6 +1145,8 @@ def torch_experts(
|
|||||||
tmp1 = a[mask].to(f32) * scales
|
tmp1 = a[mask].to(f32) * scales
|
||||||
w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
|
w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
|
||||||
tmp1 = (tmp1 @ w1_dq).to(out.dtype)
|
tmp1 = (tmp1 @ w1_dq).to(out.dtype)
|
||||||
|
if b_bias1 is not None:
|
||||||
|
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(out.dtype)
|
||||||
|
|
||||||
tmp2 = SiluAndMul()(tmp1).to(out.dtype)
|
tmp2 = SiluAndMul()(tmp1).to(out.dtype)
|
||||||
|
|
||||||
@ -1144,6 +1158,9 @@ def torch_experts(
|
|||||||
tmp2 = tmp2.to(f32) * b_scale
|
tmp2 = tmp2.to(f32) * b_scale
|
||||||
w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
|
w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
|
||||||
out[mask] = (tmp2 @ w2_dq).to(out.dtype)
|
out[mask] = (tmp2 @ w2_dq).to(out.dtype)
|
||||||
|
if b_bias2 is not None:
|
||||||
|
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(
|
||||||
|
out.dtype)
|
||||||
|
|
||||||
if apply_router_weights_on_input:
|
if apply_router_weights_on_input:
|
||||||
return out
|
return out
|
||||||
@ -1157,12 +1174,14 @@ def torch_moe(a: torch.Tensor,
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
score: torch.Tensor,
|
score: torch.Tensor,
|
||||||
topk: int,
|
topk: int,
|
||||||
|
b_bias1: Optional[torch.Tensor] = None,
|
||||||
|
b_bias2: Optional[torch.Tensor] = None,
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
|
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||||
topk_weight, topk_ids = torch.topk(score, topk)
|
topk_weight, topk_ids = torch.topk(score, topk)
|
||||||
return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts,
|
return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts,
|
||||||
expert_map)
|
b_bias1, b_bias2, expert_map)
|
||||||
|
|
||||||
|
|
||||||
def torch_moe_single(a, w, score, topk):
|
def torch_moe_single(a, w, score, topk):
|
||||||
|
|||||||
@ -57,6 +57,13 @@ V1_SUPPORTED_MODELS = [
|
|||||||
# Avoid OOM
|
# Avoid OOM
|
||||||
MAX_NUM_SEQS = 4
|
MAX_NUM_SEQS = 4
|
||||||
|
|
||||||
|
# Once we add support for FCG in Mamba1, this list will be removed and tests
|
||||||
|
# all test cases will use enforce_eager=False
|
||||||
|
ENFORCE_EAGER_MODELS_V1 = [
|
||||||
|
"state-spaces/mamba-130m-hf",
|
||||||
|
"ai21labs/Jamba-tiny-dev",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
|
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
|
||||||
@pytest.mark.parametrize("max_tokens", [64])
|
@pytest.mark.parametrize("max_tokens", [64])
|
||||||
@ -94,13 +101,19 @@ def test_models(
|
|||||||
example_prompts, max_tokens, num_logprobs)
|
example_prompts, max_tokens, num_logprobs)
|
||||||
|
|
||||||
if model in V1_SUPPORTED_MODELS:
|
if model in V1_SUPPORTED_MODELS:
|
||||||
|
enforce_eager = False
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("VLLM_USE_V1", "1")
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
if model in HYBRID_MODELS:
|
if model in HYBRID_MODELS:
|
||||||
# required due to reorder_batch behaviour
|
# required due to reorder_batch behaviour
|
||||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
|
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
|
||||||
|
|
||||||
|
if model in ENFORCE_EAGER_MODELS_V1:
|
||||||
|
enforce_eager = True
|
||||||
|
|
||||||
with vllm_runner(model,
|
with vllm_runner(model,
|
||||||
max_num_seqs=MAX_NUM_SEQS,
|
max_num_seqs=MAX_NUM_SEQS,
|
||||||
|
enforce_eager=enforce_eager,
|
||||||
enable_prefix_caching=False) as vllm_model:
|
enable_prefix_caching=False) as vllm_model:
|
||||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||||
example_prompts, max_tokens, num_logprobs)
|
example_prompts, max_tokens, num_logprobs)
|
||||||
@ -418,3 +431,65 @@ def test_full_cuda_graph(
|
|||||||
name_0="hf" if hf_outputs is not None else "vllm-v0",
|
name_0="hf" if hf_outputs is not None else "vllm-v0",
|
||||||
name_1="vllm-v1",
|
name_1="vllm-v1",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [64])
|
||||||
|
@pytest.mark.parametrize("num_logprobs", [5])
|
||||||
|
def test_fp32_state(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
example_prompts,
|
||||||
|
monkeypatch,
|
||||||
|
model: str,
|
||||||
|
max_tokens: int,
|
||||||
|
num_logprobs: int,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||||
|
model_info.check_available_online(on_fail="skip")
|
||||||
|
model_info.check_transformers_version(on_fail="skip")
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
with hf_runner(model) as hf_model:
|
||||||
|
if model not in HF_UNSUPPORTED_MODELS:
|
||||||
|
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||||
|
example_prompts, max_tokens, num_logprobs)
|
||||||
|
else:
|
||||||
|
hf_outputs = None
|
||||||
|
|
||||||
|
with vllm_runner(model,
|
||||||
|
max_num_seqs=MAX_NUM_SEQS,
|
||||||
|
mamba_ssm_cache_dtype="float32") as vllm_model:
|
||||||
|
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||||
|
example_prompts, max_tokens, num_logprobs)
|
||||||
|
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
|
if model in HYBRID_MODELS:
|
||||||
|
# required due to reorder_batch behaviour
|
||||||
|
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
|
||||||
|
with vllm_runner(model,
|
||||||
|
max_num_seqs=MAX_NUM_SEQS,
|
||||||
|
mamba_ssm_cache_dtype="float32",
|
||||||
|
enable_prefix_caching=False) as vllm_model:
|
||||||
|
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||||
|
example_prompts, max_tokens, num_logprobs)
|
||||||
|
|
||||||
|
if hf_outputs is not None:
|
||||||
|
check_logprobs_close(
|
||||||
|
outputs_0_lst=hf_outputs,
|
||||||
|
outputs_1_lst=vllm_v0_outputs,
|
||||||
|
name_0="hf",
|
||||||
|
name_1="vllm-v0",
|
||||||
|
)
|
||||||
|
|
||||||
|
ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
|
||||||
|
check_logprobs_close(
|
||||||
|
outputs_0_lst=ref_outputs,
|
||||||
|
outputs_1_lst=vllm_v1_outputs,
|
||||||
|
name_0="hf" if hf_outputs is not None else "vllm-v0",
|
||||||
|
name_1="vllm-v1",
|
||||||
|
)
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from tests.models.utils import EmbedModelInfo, RerankModelInfo
|
|||||||
# - Different model results in differences more than 1e-3
|
# - Different model results in differences more than 1e-3
|
||||||
# 1e-4 is a good tolerance threshold
|
# 1e-4 is a good tolerance threshold
|
||||||
MTEB_EMBED_TASKS = ["STS12"]
|
MTEB_EMBED_TASKS = ["STS12"]
|
||||||
MTEB_EMBED_TOL = 1e-4
|
MTEB_EMBED_TOL = 0.02
|
||||||
|
|
||||||
# See #19344
|
# See #19344
|
||||||
MTEB_RERANK_TASKS = ["NFCorpus"]
|
MTEB_RERANK_TASKS = ["NFCorpus"]
|
||||||
@ -175,6 +175,7 @@ def mteb_test_embed_models(hf_runner,
|
|||||||
with vllm_runner(model_info.name,
|
with vllm_runner(model_info.name,
|
||||||
runner="pooling",
|
runner="pooling",
|
||||||
max_model_len=None,
|
max_model_len=None,
|
||||||
|
enforce_eager=True,
|
||||||
**vllm_extra_kwargs) as vllm_model:
|
**vllm_extra_kwargs) as vllm_model:
|
||||||
|
|
||||||
model_config = vllm_model.llm.llm_engine.model_config
|
model_config = vllm_model.llm.llm_engine.model_config
|
||||||
@ -198,6 +199,7 @@ def mteb_test_embed_models(hf_runner,
|
|||||||
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
|
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
|
||||||
st_dtype = next(hf_model.model.parameters()).dtype
|
st_dtype = next(hf_model.model.parameters()).dtype
|
||||||
|
|
||||||
|
print("Model:", model_info.name)
|
||||||
print("VLLM:", vllm_dtype, vllm_main_score)
|
print("VLLM:", vllm_dtype, vllm_main_score)
|
||||||
print("SentenceTransformers:", st_dtype, st_main_score)
|
print("SentenceTransformers:", st_dtype, st_main_score)
|
||||||
print("Difference:", st_main_score - vllm_main_score)
|
print("Difference:", st_main_score - vllm_main_score)
|
||||||
@ -286,6 +288,7 @@ def mteb_test_rerank_models(hf_runner,
|
|||||||
runner="pooling",
|
runner="pooling",
|
||||||
max_model_len=None,
|
max_model_len=None,
|
||||||
max_num_seqs=8,
|
max_num_seqs=8,
|
||||||
|
enforce_eager=True,
|
||||||
**vllm_extra_kwargs) as vllm_model:
|
**vllm_extra_kwargs) as vllm_model:
|
||||||
|
|
||||||
model_config = vllm_model.llm.llm_engine.model_config
|
model_config = vllm_model.llm.llm_engine.model_config
|
||||||
@ -304,6 +307,7 @@ def mteb_test_rerank_models(hf_runner,
|
|||||||
st_main_score, st_dtype = mteb_test_rerank_models_hf(
|
st_main_score, st_dtype = mteb_test_rerank_models_hf(
|
||||||
hf_runner, model_info.name, hf_model_callback)
|
hf_runner, model_info.name, hf_model_callback)
|
||||||
|
|
||||||
|
print("Model:", model_info.name)
|
||||||
print("VLLM:", vllm_dtype, vllm_main_score)
|
print("VLLM:", vllm_dtype, vllm_main_score)
|
||||||
print("SentenceTransformers:", st_dtype, st_main_score)
|
print("SentenceTransformers:", st_dtype, st_main_score)
|
||||||
print("Difference:", st_main_score - vllm_main_score)
|
print("Difference:", st_main_score - vllm_main_score)
|
||||||
|
|||||||
@ -151,7 +151,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
"BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5",
|
"BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5",
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
"BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1",
|
"BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1",
|
||||||
min_transformers_version="4.55.1",
|
min_transformers_version="4.56.0",
|
||||||
extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501
|
extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501
|
||||||
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m",
|
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m",
|
||||||
{"1b": "bigscience/bloomz-1b1"}),
|
{"1b": "bigscience/bloomz-1b1"}),
|
||||||
@ -227,7 +227,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
|
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
|
||||||
"JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini",
|
"JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini",
|
||||||
min_transformers_version="4.55.1",
|
min_transformers_version="4.56.0",
|
||||||
extras={
|
extras={
|
||||||
"tiny": "ai21labs/Jamba-tiny-dev",
|
"tiny": "ai21labs/Jamba-tiny-dev",
|
||||||
"random": "ai21labs/Jamba-tiny-random", # noqa: E501
|
"random": "ai21labs/Jamba-tiny-random", # noqa: E501
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -72,3 +73,22 @@ def test_hash_non_contiguous_array():
|
|||||||
hasher = MultiModalHasher
|
hasher = MultiModalHasher
|
||||||
# Both should be hashable and produce the same hashes
|
# Both should be hashable and produce the same hashes
|
||||||
assert hasher.hash_kwargs(data=arr) == hasher.hash_kwargs(data=arr_c)
|
assert hasher.hash_kwargs(data=arr) == hasher.hash_kwargs(data=arr_c)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hash_image_exif_id():
|
||||||
|
# Test that EXIF ImageId tag can be used to store UUID
|
||||||
|
# and the hasher will use that instead of the image data.
|
||||||
|
image1 = image2 = Image.new("1", size=(10, 20))
|
||||||
|
id = uuid.uuid4()
|
||||||
|
image1.getexif()[Image.ExifTags.Base.ImageID] = id
|
||||||
|
image2 = Image.open(ASSETS_DIR / "image1.png")
|
||||||
|
image2.getexif()[Image.ExifTags.Base.ImageID] = "Not a UUID"
|
||||||
|
image2a = Image.open(ASSETS_DIR / "image1.png")
|
||||||
|
|
||||||
|
hasher = MultiModalHasher
|
||||||
|
# first image has UUID in ImageID, so it should hash to that UUID
|
||||||
|
assert hasher.hash_kwargs(image=image1) == hasher.hash_kwargs(
|
||||||
|
image=id.bytes)
|
||||||
|
# second image has non-UUID in ImageID, so it should hash to the image data
|
||||||
|
assert hasher.hash_kwargs(image=image2) == hasher.hash_kwargs(
|
||||||
|
image=image2a)
|
||||||
|
|||||||
@ -148,6 +148,32 @@ async def test_fetch_image_local_files(image_url: str):
|
|||||||
f"file://{temp_dir}/../{os.path.basename(image_url)}")
|
f"file://{temp_dir}/../{os.path.basename(image_url)}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_image_local_files_with_space_in_name():
|
||||||
|
image_url = TEST_IMAGE_URLS[0]
|
||||||
|
connector = MediaConnector()
|
||||||
|
|
||||||
|
with TemporaryDirectory() as temp_dir:
|
||||||
|
local_connector = MediaConnector(allowed_local_media_path=temp_dir)
|
||||||
|
|
||||||
|
origin_image = connector.fetch_image(image_url)
|
||||||
|
filename = "file name with space.jpg"
|
||||||
|
origin_image.save(os.path.join(temp_dir, filename),
|
||||||
|
quality=100,
|
||||||
|
icc_profile=origin_image.info.get('icc_profile'))
|
||||||
|
|
||||||
|
try:
|
||||||
|
image_async = await local_connector.fetch_image_async(
|
||||||
|
f"file://{temp_dir}/{filename}")
|
||||||
|
image_sync = local_connector.fetch_image(
|
||||||
|
f"file://{temp_dir}/{filename}")
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
pytest.fail(
|
||||||
|
"Failed to fetch image with space in name: {}".format(e))
|
||||||
|
# Check that the images are equal
|
||||||
|
assert not ImageChops.difference(image_sync, image_async).getbbox()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fetch_image_error_conversion():
|
async def test_fetch_image_error_conversion():
|
||||||
connector = MediaConnector()
|
connector = MediaConnector()
|
||||||
|
|||||||
@ -10,7 +10,7 @@ cd /vllm-workspace/
|
|||||||
# uninstall vllm
|
# uninstall vllm
|
||||||
pip3 uninstall -y vllm
|
pip3 uninstall -y vllm
|
||||||
# restore the original files
|
# restore the original files
|
||||||
mv test_docs/vllm ./vllm
|
mv src/vllm ./vllm
|
||||||
|
|
||||||
# remove all compilers
|
# remove all compilers
|
||||||
apt remove --purge build-essential -y
|
apt remove --purge build-essential -y
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
|
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
|
||||||
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
|
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
0
tests/v1/cudagraph/__init__.py
Normal file
0
tests/v1/cudagraph/__init__.py
Normal file
406
tests/v1/cudagraph/test_cudagraph_dispatch.py
Normal file
406
tests/v1/cudagraph/test_cudagraph_dispatch.py
Normal file
@ -0,0 +1,406 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from tests.utils import create_new_process_for_each_test
|
||||||
|
from vllm.compilation.cuda_graph import CUDAGraphWrapper
|
||||||
|
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
||||||
|
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||||
|
ParallelConfig, SchedulerConfig, VllmConfig)
|
||||||
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||||
|
|
||||||
|
|
||||||
|
# Helper MLP for testing
|
||||||
|
class SimpleMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.fc1 = nn.Linear(10, 10)
|
||||||
|
self.fc2 = nn.Linear(10, 10)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.fc2(self.fc1(x))
|
||||||
|
|
||||||
|
|
||||||
|
def _create_vllm_config(compilation_config: CompilationConfig,
|
||||||
|
max_num_seqs: int = 8) -> MagicMock:
|
||||||
|
mock_config = MagicMock(spec=VllmConfig)
|
||||||
|
mock_config.compilation_config = compilation_config
|
||||||
|
mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
|
||||||
|
mock_config.parallel_config = ParallelConfig()
|
||||||
|
|
||||||
|
# Mimic the behavior of VllmConfig.__post_init__()
|
||||||
|
if compilation_config.level == CompilationLevel.PIECEWISE:
|
||||||
|
compilation_config.set_splitting_ops_for_v1()
|
||||||
|
|
||||||
|
return mock_config
|
||||||
|
|
||||||
|
|
||||||
|
class TestCudagraphDispatcher:
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"params",
|
||||||
|
[
|
||||||
|
# Test case 0: Full CG for mixed batches, no separate routine
|
||||||
|
{
|
||||||
|
"case_id": 0,
|
||||||
|
"cudagraph_mode": "FULL",
|
||||||
|
"compilation_level": CompilationLevel.NO_COMPILATION,
|
||||||
|
},
|
||||||
|
# Test case 1: Full CG for uniform batches, piecewise for mixed
|
||||||
|
{
|
||||||
|
"case_id": 1,
|
||||||
|
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||||
|
"compilation_level": CompilationLevel.PIECEWISE,
|
||||||
|
},
|
||||||
|
# Test case 2: Full CG for uniform batches, no CG for mixed
|
||||||
|
{
|
||||||
|
"case_id": 2,
|
||||||
|
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||||
|
"compilation_level": CompilationLevel.NO_COMPILATION,
|
||||||
|
},
|
||||||
|
# Test case 3: Piecewise for all
|
||||||
|
{
|
||||||
|
"case_id": 3,
|
||||||
|
"cudagraph_mode": "PIECEWISE",
|
||||||
|
"compilation_level": CompilationLevel.PIECEWISE,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
def test_dispatcher(self, params):
|
||||||
|
# Setup dispatcher
|
||||||
|
comp_config = CompilationConfig(
|
||||||
|
cudagraph_mode=params["cudagraph_mode"],
|
||||||
|
level=params["compilation_level"],
|
||||||
|
cudagraph_capture_sizes=[1, 8])
|
||||||
|
|
||||||
|
config = _create_vllm_config(comp_config, max_num_seqs=8)
|
||||||
|
dispatcher = CudagraphDispatcher(config)
|
||||||
|
dispatcher.initialize_cudagraph_keys(
|
||||||
|
cudagraph_mode=comp_config.cudagraph_mode,
|
||||||
|
uniform_decode_query_len=1)
|
||||||
|
|
||||||
|
# Verify the key is initialized correctly
|
||||||
|
if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
|
||||||
|
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
|
||||||
|
else:
|
||||||
|
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
|
||||||
|
if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]:
|
||||||
|
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
|
||||||
|
else:
|
||||||
|
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
|
||||||
|
|
||||||
|
# Test dispatch logic
|
||||||
|
# 1. non-uniform batch, size in cudagraph size list
|
||||||
|
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
|
||||||
|
rt_mode, key = dispatcher.dispatch(desc_full_exact)
|
||||||
|
if params["cudagraph_mode"] == "FULL":
|
||||||
|
assert rt_mode == CUDAGraphMode.FULL
|
||||||
|
assert key == desc_full_exact
|
||||||
|
elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
|
||||||
|
assert rt_mode == CUDAGraphMode.PIECEWISE
|
||||||
|
assert key == desc_full_exact
|
||||||
|
else:
|
||||||
|
assert rt_mode == CUDAGraphMode.NONE
|
||||||
|
|
||||||
|
# 2. uniform decode batch, size in cudagraph size list
|
||||||
|
desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
|
||||||
|
rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
|
||||||
|
if params["cudagraph_mode"] == "FULL":
|
||||||
|
assert rt_mode == CUDAGraphMode.FULL
|
||||||
|
assert key == desc_uniform_exact.non_uniform
|
||||||
|
elif params["cudagraph_mode"] in [
|
||||||
|
"FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"
|
||||||
|
]:
|
||||||
|
assert rt_mode == CUDAGraphMode.FULL
|
||||||
|
assert key == desc_uniform_exact
|
||||||
|
elif params["cudagraph_mode"] == "PIECEWISE":
|
||||||
|
assert rt_mode == CUDAGraphMode.PIECEWISE
|
||||||
|
assert key == desc_uniform_exact.non_uniform
|
||||||
|
else:
|
||||||
|
assert rt_mode == CUDAGraphMode.NONE
|
||||||
|
|
||||||
|
# 3. No key match
|
||||||
|
desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False)
|
||||||
|
rt_mode, key = dispatcher.dispatch(desc_no_match)
|
||||||
|
assert rt_mode == CUDAGraphMode.NONE
|
||||||
|
assert key is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||||
|
class TestCUDAGraphWrapper:
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
self.vllm_config = _create_vllm_config(CompilationConfig())
|
||||||
|
self.model = SimpleMLP().to("cuda")
|
||||||
|
self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
|
||||||
|
self.input_tensor = torch.randn(1, 10, device="cuda")
|
||||||
|
|
||||||
|
@create_new_process_for_each_test("spawn")
|
||||||
|
def test_capture_and_replay(self):
|
||||||
|
wrapper = CUDAGraphWrapper(self.model,
|
||||||
|
self.vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL)
|
||||||
|
batch_descriptor = BatchDescriptor(num_tokens=10)
|
||||||
|
|
||||||
|
# 0. global warmup
|
||||||
|
with set_forward_context(attn_metadata=None,
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||||
|
batch_descriptor=None):
|
||||||
|
wrapper(self.input_tensor)
|
||||||
|
|
||||||
|
# 1. Capture
|
||||||
|
with set_forward_context(
|
||||||
|
attn_metadata=None,
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
cudagraph_runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
batch_descriptor=batch_descriptor),\
|
||||||
|
patch("torch.cuda.graph",
|
||||||
|
wraps=torch.cuda.graph) as mock_cuda_graph:
|
||||||
|
output1 = wrapper(self.input_tensor)
|
||||||
|
# capturing phase should generate a zero output
|
||||||
|
assert torch.allclose(output1, torch.zeros_like(output1))
|
||||||
|
mock_cuda_graph.assert_called_once()
|
||||||
|
|
||||||
|
assert batch_descriptor in wrapper.concrete_cudagraph_entries
|
||||||
|
entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
|
||||||
|
assert entry.cudagraph is not None
|
||||||
|
|
||||||
|
# 2. Replay
|
||||||
|
with set_forward_context(
|
||||||
|
attn_metadata=None,
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
cudagraph_runtime_mode=CUDAGraphMode.FULL,
|
||||||
|
batch_descriptor=batch_descriptor),\
|
||||||
|
patch.object(entry.cudagraph, 'replay',
|
||||||
|
wraps=entry.cudagraph.replay) as mock_replay:
|
||||||
|
output2 = wrapper(self.input_tensor)
|
||||||
|
mock_replay.assert_called_once()
|
||||||
|
|
||||||
|
# Compare with eager output
|
||||||
|
eager_output = self.model(self.input_tensor)
|
||||||
|
torch.testing.assert_close(eager_output, output2)
|
||||||
|
|
||||||
|
@create_new_process_for_each_test("spawn")
|
||||||
|
def test_bypass_on_mode_mismatch(self):
|
||||||
|
wrapper = CUDAGraphWrapper(self.model,
|
||||||
|
self.vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL)
|
||||||
|
batch_descriptor = BatchDescriptor(num_tokens=10)
|
||||||
|
|
||||||
|
with set_forward_context(
|
||||||
|
attn_metadata=None,
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||||
|
batch_descriptor=batch_descriptor), \
|
||||||
|
patch('torch.cuda.graph',
|
||||||
|
wraps=torch.cuda.graph) as mock_cuda_graph, \
|
||||||
|
patch.object(self.model, 'forward',
|
||||||
|
wraps=self.model.forward) as mock_forward:
|
||||||
|
wrapper(self.input_tensor)
|
||||||
|
mock_cuda_graph.assert_not_called()
|
||||||
|
mock_forward.assert_called_once()
|
||||||
|
assert not wrapper.concrete_cudagraph_entries
|
||||||
|
|
||||||
|
@create_new_process_for_each_test("spawn")
|
||||||
|
def test_bypass_on_mode_none(self):
|
||||||
|
wrapper = CUDAGraphWrapper(self.model,
|
||||||
|
self.vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.FULL)
|
||||||
|
batch_descriptor = BatchDescriptor(num_tokens=10)
|
||||||
|
|
||||||
|
with set_forward_context(
|
||||||
|
attn_metadata=None,
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||||
|
batch_descriptor=batch_descriptor), \
|
||||||
|
patch('torch.cuda.graph',
|
||||||
|
wraps=torch.cuda.graph) as mock_cuda_graph:
|
||||||
|
wrapper(self.input_tensor)
|
||||||
|
mock_cuda_graph.assert_not_called()
|
||||||
|
assert not wrapper.concrete_cudagraph_entries
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||||
|
class TestCudagraphIntegration:
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
# only FULL mode for non-uniform batches
|
||||||
|
self.comp_config = CompilationConfig(level=CompilationLevel.PIECEWISE,
|
||||||
|
cudagraph_mode="FULL",
|
||||||
|
cudagraph_capture_sizes=[10, 20])
|
||||||
|
self.vllm_config = _create_vllm_config(self.comp_config)
|
||||||
|
self.dispatcher = CudagraphDispatcher(self.vllm_config)
|
||||||
|
self.dispatcher.initialize_cudagraph_keys(
|
||||||
|
self.comp_config.cudagraph_mode, uniform_decode_query_len=1)
|
||||||
|
|
||||||
|
def _run_and_monitor_call(self, wrapper, input_tensor, runtime_mode,
|
||||||
|
batch_descriptor):
|
||||||
|
"""Helper to run a single call and monitor the action."""
|
||||||
|
|
||||||
|
with patch('torch.cuda.graph',
|
||||||
|
wraps=torch.cuda.graph) as mock_graph_context, \
|
||||||
|
patch.object(wrapper, 'runnable',
|
||||||
|
wraps=wrapper.runnable) as mock_runnable:
|
||||||
|
|
||||||
|
entry = wrapper.concrete_cudagraph_entries.get(
|
||||||
|
batch_descriptor, None)
|
||||||
|
|
||||||
|
context = set_forward_context(attn_metadata=None,
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
cudagraph_runtime_mode=runtime_mode,
|
||||||
|
batch_descriptor=batch_descriptor)
|
||||||
|
mock_replay = MagicMock()
|
||||||
|
if entry and entry.cudagraph:
|
||||||
|
with context, \
|
||||||
|
patch.object(entry.cudagraph, 'replay',
|
||||||
|
new_callable=MagicMock) as mock_replay:
|
||||||
|
wrapper(input_tensor)
|
||||||
|
else:
|
||||||
|
with context:
|
||||||
|
wrapper(input_tensor)
|
||||||
|
|
||||||
|
if mock_graph_context.called:
|
||||||
|
# note that this is globally mocked, so it will be detected
|
||||||
|
# even whether called by the inner or outer wrapper
|
||||||
|
return "capture_global"
|
||||||
|
if mock_replay.called:
|
||||||
|
# only for outer wrapper
|
||||||
|
return "replay"
|
||||||
|
if mock_runnable.call_count > 0:
|
||||||
|
# only for outer wrapper
|
||||||
|
return "bypass"
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
@create_new_process_for_each_test("spawn")
|
||||||
|
def test_capture_replay_bypass_logic(self):
|
||||||
|
model = SimpleMLP().to("cuda")
|
||||||
|
full_wrapper = CUDAGraphWrapper(model, self.vllm_config,
|
||||||
|
CUDAGraphMode.FULL)
|
||||||
|
max_bs = 16
|
||||||
|
persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
|
||||||
|
input_1 = persistent_input_buffer[:1]
|
||||||
|
input_2 = persistent_input_buffer[:2]
|
||||||
|
input_3 = persistent_input_buffer[:3]
|
||||||
|
|
||||||
|
desc_1 = BatchDescriptor(num_tokens=1)
|
||||||
|
desc_2 = BatchDescriptor(num_tokens=2)
|
||||||
|
desc_3_unseen = BatchDescriptor(num_tokens=3)
|
||||||
|
|
||||||
|
# 0. global warmup
|
||||||
|
with set_forward_context(attn_metadata=None,
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||||
|
batch_descriptor=None):
|
||||||
|
full_wrapper(input_1)
|
||||||
|
|
||||||
|
rt_mode, key = self.dispatcher.dispatch(desc_1)
|
||||||
|
# 1. Capture first shape
|
||||||
|
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode,
|
||||||
|
key)
|
||||||
|
assert action == "capture_global"
|
||||||
|
|
||||||
|
# 2. Replay first shape
|
||||||
|
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode,
|
||||||
|
key)
|
||||||
|
assert action == "replay"
|
||||||
|
|
||||||
|
rt_mode, key = self.dispatcher.dispatch(desc_2)
|
||||||
|
# 3. Capture second shape
|
||||||
|
action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode,
|
||||||
|
key)
|
||||||
|
assert action == "capture_global"
|
||||||
|
|
||||||
|
# 4. Replay second shape
|
||||||
|
action = self._run_and_monitor_call(full_wrapper, input_2,
|
||||||
|
CUDAGraphMode.FULL, desc_2)
|
||||||
|
assert action == "replay"
|
||||||
|
|
||||||
|
# 5. Bypass if no key match
|
||||||
|
rt_mode, key = self.dispatcher.dispatch(desc_3_unseen)
|
||||||
|
assert rt_mode == CUDAGraphMode.NONE
|
||||||
|
action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode,
|
||||||
|
key)
|
||||||
|
assert action == "bypass"
|
||||||
|
|
||||||
|
# capture unseen shape is not allowed after disable
|
||||||
|
set_cudagraph_capturing_enabled(False)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
self._run_and_monitor_call(full_wrapper, input_3,
|
||||||
|
CUDAGraphMode.FULL, desc_3_unseen)
|
||||||
|
set_cudagraph_capturing_enabled(True)
|
||||||
|
|
||||||
|
@create_new_process_for_each_test("spawn")
|
||||||
|
def test_nested_wrappers(self):
|
||||||
|
"""Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
|
||||||
|
model = SimpleMLP().to("cuda")
|
||||||
|
full_wrapper = CUDAGraphWrapper(model, self.vllm_config,
|
||||||
|
CUDAGraphMode.FULL)
|
||||||
|
input_1 = torch.randn(1, 10, device="cuda")
|
||||||
|
|
||||||
|
# Setup: Inner model is wrapped with PIECEWISE, outer with FULL
|
||||||
|
inner_model = SimpleMLP().to("cuda")
|
||||||
|
piecewise_wrapper = CUDAGraphWrapper(inner_model, self.vllm_config,
|
||||||
|
CUDAGraphMode.PIECEWISE)
|
||||||
|
inner_model.forward = MagicMock(wraps=inner_model.forward)
|
||||||
|
outer_model = SimpleMLP().to("cuda")
|
||||||
|
# When outer model is called, it calls the piecewise_wrapper
|
||||||
|
outer_model.forward = MagicMock(wraps=outer_model.forward,
|
||||||
|
side_effect=piecewise_wrapper)
|
||||||
|
full_wrapper = CUDAGraphWrapper(outer_model, self.vllm_config,
|
||||||
|
CUDAGraphMode.FULL)
|
||||||
|
|
||||||
|
desc_1 = BatchDescriptor(num_tokens=1)
|
||||||
|
|
||||||
|
# 0. global warmup
|
||||||
|
with set_forward_context(attn_metadata=None,
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||||
|
batch_descriptor=None):
|
||||||
|
full_wrapper(input_1)
|
||||||
|
|
||||||
|
# --- Test runtime mode FULL---
|
||||||
|
# Run with FULL mode context. Expect outer wrapper to capture.
|
||||||
|
# The inner mock should be called once inside the graph capture.
|
||||||
|
outer_model.forward.reset_mock()
|
||||||
|
inner_model.forward.reset_mock()
|
||||||
|
action = self._run_and_monitor_call(full_wrapper, input_1,
|
||||||
|
CUDAGraphMode.FULL, desc_1)
|
||||||
|
assert action == "capture_global"
|
||||||
|
assert outer_model.forward.call_count == 1
|
||||||
|
assert inner_model.forward.call_count == 1
|
||||||
|
|
||||||
|
# Run again. Expect outer wrapper to replay.
|
||||||
|
# The outer model should NOT be called because the whole graph
|
||||||
|
# is replayed.
|
||||||
|
action = self._run_and_monitor_call(full_wrapper, input_1,
|
||||||
|
CUDAGraphMode.FULL, desc_1)
|
||||||
|
assert action == "replay"
|
||||||
|
assert outer_model.forward.call_count == 1 # No new call
|
||||||
|
assert inner_model.forward.call_count == 1
|
||||||
|
|
||||||
|
# --- Test runtime mode PIECEWISE ---
|
||||||
|
outer_model.forward.reset_mock()
|
||||||
|
inner_model.forward.reset_mock()
|
||||||
|
# Run with PIECEWISE mode context.
|
||||||
|
# Expect outer wrapper to bypass and call inner wrapper.
|
||||||
|
# Inner wrapper should capture.
|
||||||
|
action = self._run_and_monitor_call(full_wrapper, input_1,
|
||||||
|
CUDAGraphMode.PIECEWISE, desc_1)
|
||||||
|
assert action == "capture_global"
|
||||||
|
assert outer_model.forward.call_count == 1
|
||||||
|
assert inner_model.forward.call_count == 1
|
||||||
|
|
||||||
|
# Run again with PIECEWISE.
|
||||||
|
# Outer bypasses, inner replays.
|
||||||
|
action = self._run_and_monitor_call(full_wrapper, input_1,
|
||||||
|
CUDAGraphMode.PIECEWISE, desc_1)
|
||||||
|
assert action == "bypass"
|
||||||
|
assert outer_model.forward.call_count == 2
|
||||||
|
assert inner_model.forward.call_count == 1
|
||||||
187
tests/v1/cudagraph/test_cudagraph_mode.py
Normal file
187
tests/v1/cudagraph/test_cudagraph_mode.py
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import contextlib
|
||||||
|
import os
|
||||||
|
import weakref
|
||||||
|
from contextlib import ExitStack
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tests.utils import wait_for_gpu_memory_to_clear
|
||||||
|
from vllm import LLM
|
||||||
|
from vllm.config import CompilationConfig
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def temporary_environ(env_vars):
|
||||||
|
"""
|
||||||
|
Temporarily set environment variables and restore them afterward.
|
||||||
|
We have to do this vs monkeypatch because monkeypatch doesn't work
|
||||||
|
with "module" scoped fixtures.
|
||||||
|
"""
|
||||||
|
original_env = {k: os.environ.get(k) for k in env_vars}
|
||||||
|
try:
|
||||||
|
os.environ.update(env_vars)
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
for k, v in original_env.items():
|
||||||
|
if v is None:
|
||||||
|
os.environ.pop(k, None)
|
||||||
|
else:
|
||||||
|
os.environ[k] = v
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BackendConfig:
|
||||||
|
name: str
|
||||||
|
env_vars: dict
|
||||||
|
comp_config: dict
|
||||||
|
specific_gpu_arch: Optional[tuple] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Define all backend configurations of full cudagraph to be tested
|
||||||
|
backend_configs = {
|
||||||
|
# FA3 on Hopper
|
||||||
|
"FA3":
|
||||||
|
BackendConfig(name="FA3",
|
||||||
|
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
|
||||||
|
comp_config={
|
||||||
|
"cudagraph_mode": "FULL",
|
||||||
|
},
|
||||||
|
specific_gpu_arch=(9, 0)),
|
||||||
|
# FlashMLA on Hopper
|
||||||
|
"FlashMLA":
|
||||||
|
BackendConfig(name="FlashMLA",
|
||||||
|
env_vars={
|
||||||
|
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
|
||||||
|
},
|
||||||
|
comp_config={
|
||||||
|
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||||
|
},
|
||||||
|
specific_gpu_arch=(9, 0)),
|
||||||
|
# FA2
|
||||||
|
"FA2":
|
||||||
|
BackendConfig(name="FA2",
|
||||||
|
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
|
||||||
|
comp_config={
|
||||||
|
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||||
|
}),
|
||||||
|
# Triton Attention
|
||||||
|
"TritonAttn":
|
||||||
|
BackendConfig(name="TritonAttn",
|
||||||
|
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
|
||||||
|
comp_config={
|
||||||
|
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||||
|
}),
|
||||||
|
# FlashInfer
|
||||||
|
"FlashInfer":
|
||||||
|
BackendConfig(name="FlashInfer",
|
||||||
|
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
|
||||||
|
comp_config={
|
||||||
|
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
# test attention backend and cudagraph_mode combo
|
||||||
|
# (backend_name, cudagraph_mode, supported)
|
||||||
|
combo_cases_1 = [
|
||||||
|
("FA3", "FULL", True),
|
||||||
|
("FA3", "FULL_AND_PIECEWISE", True),
|
||||||
|
("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
|
||||||
|
("FA2", "FULL_AND_PIECEWISE", True),
|
||||||
|
("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
|
||||||
|
("FlashInfer", "FULL_AND_PIECEWISE", True),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("combo_case", combo_cases_1)
|
||||||
|
def test_backend_and_cudagraph_mode_combo(combo_case):
|
||||||
|
backend_name, cudagraph_mode, supported = combo_case
|
||||||
|
if backend_name == "FlashInfer":
|
||||||
|
try:
|
||||||
|
import flashinfer # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("FlashInfer is not installed")
|
||||||
|
backend_config = backend_configs[backend_name]
|
||||||
|
# Dynamically skip test if GPU capability is not met
|
||||||
|
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
|
||||||
|
!= current_platform.get_device_capability():
|
||||||
|
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
|
||||||
|
|
||||||
|
env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars}
|
||||||
|
|
||||||
|
with temporary_environ(env_vars), ExitStack() as stack:
|
||||||
|
if not supported:
|
||||||
|
stack.enter_context(pytest.raises(Exception))
|
||||||
|
|
||||||
|
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
|
||||||
|
max_num_seqs=256,
|
||||||
|
trust_remote_code=True,
|
||||||
|
gpu_memory_utilization=0.45,
|
||||||
|
max_model_len=1024,
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
|
level=3, cudagraph_mode=cudagraph_mode))
|
||||||
|
llm.generate(["Hello, my name is"] * 10)
|
||||||
|
|
||||||
|
try:
|
||||||
|
llm = weakref.proxy(llm)
|
||||||
|
del llm
|
||||||
|
except UnboundLocalError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
wait_for_gpu_memory_to_clear(
|
||||||
|
devices=[0],
|
||||||
|
threshold_ratio=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# test cudagraph_mode with different compilation level.
|
||||||
|
# (backend_name, cudagraph_mode, compilation_level, supported)
|
||||||
|
combo_cases_2 = [
|
||||||
|
("FA2", "FULL", 0, True), # no compilation + full cudagraph
|
||||||
|
("FA2", "FULL", 3, True), # piecewise compilation + full cudagraph
|
||||||
|
("FA2", "PIECEWISE", 0, False), # no compilation + piecewise cudagraph
|
||||||
|
("FA2", "PIECEWISE", 3,
|
||||||
|
True), # piecewise compilation + piecewise cudagraph
|
||||||
|
("FA2", "FULL_AND_PIECEWISE", 0,
|
||||||
|
False), # piecewise cudagraph not supported without piecewise compilation
|
||||||
|
("FA2", "FULL_AND_PIECEWISE", 3, True),
|
||||||
|
("FA2", "FULL_DECODE_ONLY", 0, True),
|
||||||
|
("FA2", "FULL_DECODE_ONLY", 3, True),
|
||||||
|
("FA2", "NONE", 0, True), # no compilation + no cudagraph
|
||||||
|
("FA2", "NONE", 3, True), # piecewise compilation + no cudagraph
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("combo_case", combo_cases_2)
|
||||||
|
def test_cudagraph_compilation_combo(combo_case):
|
||||||
|
backend_name, cudagraph_mode, compilation_level, supported\
|
||||||
|
= combo_case
|
||||||
|
|
||||||
|
env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars}
|
||||||
|
|
||||||
|
with temporary_environ(env_vars), ExitStack() as stack:
|
||||||
|
if not supported:
|
||||||
|
stack.enter_context(pytest.raises(Exception))
|
||||||
|
|
||||||
|
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
|
||||||
|
max_num_seqs=256,
|
||||||
|
trust_remote_code=True,
|
||||||
|
gpu_memory_utilization=0.45,
|
||||||
|
max_model_len=1024,
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
|
level=compilation_level, cudagraph_mode=cudagraph_mode))
|
||||||
|
llm.generate(["Hello, my name is"] * 10)
|
||||||
|
try:
|
||||||
|
llm = weakref.proxy(llm)
|
||||||
|
del llm
|
||||||
|
except UnboundLocalError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
wait_for_gpu_memory_to_clear(
|
||||||
|
devices=[0],
|
||||||
|
threshold_ratio=0.1,
|
||||||
|
)
|
||||||
@ -162,6 +162,12 @@ def test_eagle_correctness(
|
|||||||
mm_enabled: bool,
|
mm_enabled: bool,
|
||||||
attn_backend: str,
|
attn_backend: str,
|
||||||
):
|
):
|
||||||
|
if attn_backend == "TREE_ATTN":
|
||||||
|
# TODO: Fix this flaky test
|
||||||
|
pytest.skip(
|
||||||
|
"TREE_ATTN is flaky in the test disable for now until it can be "
|
||||||
|
"reolved (see https://github.com/vllm-project/vllm/issues/22922)")
|
||||||
|
|
||||||
# Generate test prompts inside the function instead of using fixture
|
# Generate test prompts inside the function instead of using fixture
|
||||||
test_prompts = get_test_prompts(mm_enabled)
|
test_prompts = get_test_prompts(mm_enabled)
|
||||||
'''
|
'''
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from vllm.assets.image import ImageAsset
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.inputs import PromptType
|
from vllm.inputs import PromptType
|
||||||
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.sampling_params import RequestOutputKind
|
from vllm.sampling_params import RequestOutputKind
|
||||||
from vllm.utils import set_default_torch_num_threads
|
from vllm.utils import set_default_torch_num_threads
|
||||||
@ -398,3 +399,89 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch):
|
|||||||
|
|
||||||
# Test 3: Verify healthy engine still works after mock
|
# Test 3: Verify healthy engine still works after mock
|
||||||
await engine.check_health()
|
await engine.check_health()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_abort_final_output(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
output_kind: RequestOutputKind,
|
||||||
|
):
|
||||||
|
"""Test that abort() returns a final output with correct information."""
|
||||||
|
|
||||||
|
with monkeypatch.context() as m, ExitStack() as after:
|
||||||
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
|
|
||||||
|
with set_default_torch_num_threads(1):
|
||||||
|
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
|
||||||
|
after.callback(engine.shutdown)
|
||||||
|
|
||||||
|
request_id = "test-abort-final-output"
|
||||||
|
|
||||||
|
# Start a long-running request
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
max_tokens=3000, # Long enough to allow abort
|
||||||
|
ignore_eos=True,
|
||||||
|
output_kind=output_kind,
|
||||||
|
temperature=0.5,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs: list[RequestOutput] = []
|
||||||
|
generated = asyncio.create_task(
|
||||||
|
collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params,
|
||||||
|
outputs))
|
||||||
|
|
||||||
|
# Let it generate some tokens
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
# Abort the request
|
||||||
|
await engine.abort(request_id)
|
||||||
|
|
||||||
|
# Wait for generation to complete and return final output
|
||||||
|
final_output = await generated
|
||||||
|
|
||||||
|
# Verify we got a final output
|
||||||
|
assert final_output is not None
|
||||||
|
assert final_output.finished
|
||||||
|
assert len(final_output.outputs) == 1
|
||||||
|
|
||||||
|
assert final_output.outputs[0].finish_reason == "abort"
|
||||||
|
assert final_output.outputs[0].stop_reason is None
|
||||||
|
|
||||||
|
# Verify num_cached_tokens is set correctly
|
||||||
|
assert hasattr(final_output, 'num_cached_tokens')
|
||||||
|
assert final_output.num_cached_tokens >= 0
|
||||||
|
|
||||||
|
# If we got intermediate outputs, verify they are consistent
|
||||||
|
if output_kind == RequestOutputKind.DELTA:
|
||||||
|
# For DELTA, sum all intermediate tokens should <= final tokens
|
||||||
|
token_count = sum(
|
||||||
|
len(output.outputs[0].token_ids) for output in outputs)
|
||||||
|
assert token_count > 0
|
||||||
|
assert len(final_output.outputs[0].token_ids) == 0
|
||||||
|
else:
|
||||||
|
# For FINAL_ONLY, we should only get the final output
|
||||||
|
assert len(outputs) == 0
|
||||||
|
assert len(final_output.outputs[0].token_ids) > 0
|
||||||
|
|
||||||
|
assert not engine.output_processor.has_unfinished_requests()
|
||||||
|
|
||||||
|
|
||||||
|
async def collect_outputs(
|
||||||
|
engine: AsyncLLM,
|
||||||
|
request_id: str,
|
||||||
|
prompt: PromptType,
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
outputs_list: list[RequestOutput],
|
||||||
|
) -> Optional[RequestOutput]:
|
||||||
|
"""Helper to collect outputs and return the final one."""
|
||||||
|
final_output: Optional[RequestOutput] = None
|
||||||
|
async for output in engine.generate(request_id=request_id,
|
||||||
|
prompt=prompt,
|
||||||
|
sampling_params=sampling_params):
|
||||||
|
if not output.finished:
|
||||||
|
outputs_list.append(output)
|
||||||
|
final_output = output
|
||||||
|
return final_output
|
||||||
|
|||||||
@ -4,6 +4,8 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
|
from typing import Optional, cast
|
||||||
|
|
||||||
import openai # use the official client for correctness check
|
import openai # use the official client for correctness check
|
||||||
import pytest
|
import pytest
|
||||||
@ -41,12 +43,15 @@ class MultinodeInternalLBServerManager:
|
|||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
self.api_server_count = api_server_count
|
self.api_server_count = api_server_count
|
||||||
self.base_server_args = base_server_args
|
self.base_server_args = base_server_args
|
||||||
self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
|
self.servers: list[Optional[tuple[RemoteOpenAIServer,
|
||||||
|
list[str]]]] = [None] * (dp_size //
|
||||||
|
dp_per_node)
|
||||||
self.server_threads: list[threading.Thread] = []
|
self.server_threads: list[threading.Thread] = []
|
||||||
|
|
||||||
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
|
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
|
||||||
"""Start all server instances for multi-node internal LB mode."""
|
"""Start all server instances for multi-node internal LB mode."""
|
||||||
for rank in range(0, self.dp_size, self.dp_per_node):
|
for server_idx, rank in enumerate(
|
||||||
|
range(0, self.dp_size, self.dp_per_node)):
|
||||||
# Create server args for this specific rank
|
# Create server args for this specific rank
|
||||||
server_args = self.base_server_args.copy()
|
server_args = self.base_server_args.copy()
|
||||||
|
|
||||||
@ -87,7 +92,7 @@ class MultinodeInternalLBServerManager:
|
|||||||
])
|
])
|
||||||
|
|
||||||
# Use a thread to start each server to allow parallel initialization
|
# Use a thread to start each server to allow parallel initialization
|
||||||
def start_server(r: int, sargs: list[str]):
|
def start_server(sidx: int, r: int, sargs: list[str]):
|
||||||
gpus_per_node = self.tp_size * self.dp_per_node
|
gpus_per_node = self.tp_size * self.dp_per_node
|
||||||
try:
|
try:
|
||||||
# Start the server
|
# Start the server
|
||||||
@ -110,13 +115,14 @@ class MultinodeInternalLBServerManager:
|
|||||||
f"{self.api_server_count} API servers")
|
f"{self.api_server_count} API servers")
|
||||||
else:
|
else:
|
||||||
print(f"Headless node (rank {r}) started successfully")
|
print(f"Headless node (rank {r}) started successfully")
|
||||||
self.servers.append((server, sargs))
|
self.servers[sidx] = (server, sargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to start server rank {r}: {e}")
|
print(f"Failed to start server rank {r}: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
thread = threading.Thread(target=start_server,
|
thread = threading.Thread(target=start_server,
|
||||||
args=(rank, server_args))
|
args=(server_idx, rank, server_args))
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
self.server_threads.append(thread)
|
self.server_threads.append(thread)
|
||||||
@ -128,18 +134,20 @@ class MultinodeInternalLBServerManager:
|
|||||||
# Give servers additional time to fully initialize and coordinate
|
# Give servers additional time to fully initialize and coordinate
|
||||||
time.sleep(3)
|
time.sleep(3)
|
||||||
|
|
||||||
if len(self.servers) != self.dp_size // self.dp_per_node:
|
if not all(self.servers):
|
||||||
raise Exception("Servers failed to start")
|
raise Exception("Servers failed to start")
|
||||||
|
|
||||||
return self.servers
|
return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers)
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
"""Stop all server instances."""
|
"""Stop all server instances."""
|
||||||
while self.servers:
|
while self.servers:
|
||||||
try:
|
if server := self.servers.pop():
|
||||||
self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
|
try:
|
||||||
except Exception as e:
|
server[0].__exit__(exc_type, exc_val, exc_tb)
|
||||||
print(f"Error stopping server: {e}")
|
except Exception as e:
|
||||||
|
print(f"Error stopping server: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
class APIOnlyServerManager:
|
class APIOnlyServerManager:
|
||||||
@ -157,7 +165,8 @@ class APIOnlyServerManager:
|
|||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
self.api_server_count = api_server_count
|
self.api_server_count = api_server_count
|
||||||
self.base_server_args = base_server_args
|
self.base_server_args = base_server_args
|
||||||
self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
|
self.servers: list[Optional[tuple[RemoteOpenAIServer,
|
||||||
|
list[str]]]] = [None] * 2
|
||||||
self.server_threads: list[threading.Thread] = []
|
self.server_threads: list[threading.Thread] = []
|
||||||
|
|
||||||
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
|
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
|
||||||
@ -209,7 +218,7 @@ class APIOnlyServerManager:
|
|||||||
server.__enter__()
|
server.__enter__()
|
||||||
print(f"API-only server started successfully with "
|
print(f"API-only server started successfully with "
|
||||||
f"{self.api_server_count} API servers")
|
f"{self.api_server_count} API servers")
|
||||||
self.servers.append((server, api_server_args))
|
self.servers[0] = (server, api_server_args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to start API-only server: {e}")
|
print(f"Failed to start API-only server: {e}")
|
||||||
raise
|
raise
|
||||||
@ -231,7 +240,7 @@ class APIOnlyServerManager:
|
|||||||
server.__enter__()
|
server.__enter__()
|
||||||
print(f"Headless engines server started successfully with "
|
print(f"Headless engines server started successfully with "
|
||||||
f"{self.dp_size} engines")
|
f"{self.dp_size} engines")
|
||||||
self.servers.append((server, engines_server_args))
|
self.servers[1] = (server, engines_server_args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to start headless engines server: {e}")
|
print(f"Failed to start headless engines server: {e}")
|
||||||
raise
|
raise
|
||||||
@ -253,18 +262,20 @@ class APIOnlyServerManager:
|
|||||||
# Give servers additional time to fully initialize and coordinate
|
# Give servers additional time to fully initialize and coordinate
|
||||||
time.sleep(3)
|
time.sleep(3)
|
||||||
|
|
||||||
if len(self.servers) != 2:
|
if not all(self.servers):
|
||||||
raise Exception("Both servers failed to start")
|
raise Exception("Both servers failed to start")
|
||||||
|
|
||||||
return self.servers
|
return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers)
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
"""Stop both server instances."""
|
"""Stop both server instances."""
|
||||||
while self.servers:
|
while self.servers:
|
||||||
try:
|
if server := self.servers.pop():
|
||||||
self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
|
try:
|
||||||
except Exception as e:
|
server[0].__exit__(exc_type, exc_val, exc_tb)
|
||||||
print(f"Error stopping server: {e}")
|
except Exception as e:
|
||||||
|
print(f"Error stopping server: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
@ -560,7 +571,7 @@ async def test_api_only_multinode_dp_completion(
|
|||||||
assert len(results) == num_requests
|
assert len(results) == num_requests
|
||||||
assert all(completion is not None for completion in results)
|
assert all(completion is not None for completion in results)
|
||||||
|
|
||||||
_, api_server_args = api_only_servers[0]
|
api_server, api_server_args = api_only_servers[0]
|
||||||
api_server_count = (
|
api_server_count = (
|
||||||
api_server_args.count('--api-server-count')
|
api_server_args.count('--api-server-count')
|
||||||
and api_server_args[api_server_args.index('--api-server-count') + 1]
|
and api_server_args[api_server_args.index('--api-server-count') + 1]
|
||||||
@ -569,7 +580,6 @@ async def test_api_only_multinode_dp_completion(
|
|||||||
f"engines on headless server (API server count: {api_server_count})")
|
f"engines on headless server (API server count: {api_server_count})")
|
||||||
|
|
||||||
# Check request balancing via Prometheus metrics
|
# Check request balancing via Prometheus metrics
|
||||||
api_server = api_only_servers[0][0]
|
|
||||||
check_request_balancing(api_server, DP_SIZE)
|
check_request_balancing(api_server, DP_SIZE)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -772,6 +772,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
|||||||
head_dim=hf_config.mamba_d_head,
|
head_dim=hf_config.mamba_d_head,
|
||||||
rms_norm_eps=hf_config.rms_norm_eps,
|
rms_norm_eps=hf_config.rms_norm_eps,
|
||||||
activation=hf_config.hidden_act,
|
activation=hf_config.hidden_act,
|
||||||
|
cache_config=cache_config,
|
||||||
|
model_config=model_config,
|
||||||
prefix=key,
|
prefix=key,
|
||||||
)
|
)
|
||||||
# suppress var not used error
|
# suppress var not used error
|
||||||
|
|||||||
@ -319,38 +319,6 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
|
|||||||
repetition_penalties)
|
repetition_penalties)
|
||||||
|
|
||||||
|
|
||||||
def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
|
|
||||||
input_tokens: torch.Tensor,
|
|
||||||
sampled_token_ids: torch.Tensor,
|
|
||||||
input_positions: torch.Tensor,
|
|
||||||
seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
|
|
||||||
block_tables: torch.Tensor) -> None:
|
|
||||||
"""Advance a step on GPU for existing inputs for a multi-step runner"""
|
|
||||||
return torch.ops._C.advance_step_flashattn(num_seqs, num_queries,
|
|
||||||
block_size, input_tokens,
|
|
||||||
sampled_token_ids,
|
|
||||||
input_positions, seq_lens,
|
|
||||||
slot_mapping, block_tables)
|
|
||||||
|
|
||||||
|
|
||||||
def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
|
|
||||||
input_tokens: torch.Tensor,
|
|
||||||
sampled_token_ids: torch.Tensor,
|
|
||||||
input_positions: torch.Tensor,
|
|
||||||
seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
paged_kv_indices: torch.Tensor,
|
|
||||||
paged_kv_indptr: torch.Tensor,
|
|
||||||
paged_kv_last_page_len: torch.Tensor,
|
|
||||||
block_table_bound: torch.Tensor) -> None:
|
|
||||||
|
|
||||||
return torch.ops._C.advance_step_flashinfer(
|
|
||||||
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
|
|
||||||
input_positions, seq_lens, slot_mapping, block_tables,
|
|
||||||
paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len,
|
|
||||||
block_table_bound)
|
|
||||||
|
|
||||||
|
|
||||||
# fused quant layer norm ops
|
# fused quant layer norm ops
|
||||||
def rms_norm_dynamic_per_token_quant(
|
def rms_norm_dynamic_per_token_quant(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
@ -452,6 +420,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
|||||||
def _gptq_marlin_gemm_fake(a: torch.Tensor,
|
def _gptq_marlin_gemm_fake(a: torch.Tensor,
|
||||||
c: Optional[torch.Tensor],
|
c: Optional[torch.Tensor],
|
||||||
b_q_weight: torch.Tensor,
|
b_q_weight: torch.Tensor,
|
||||||
|
b_bias: Optional[torch.Tensor],
|
||||||
b_scales: torch.Tensor,
|
b_scales: torch.Tensor,
|
||||||
global_scale: Optional[torch.Tensor],
|
global_scale: Optional[torch.Tensor],
|
||||||
b_zeros: Optional[torch.Tensor],
|
b_zeros: Optional[torch.Tensor],
|
||||||
@ -1048,6 +1017,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
|
|||||||
def gptq_marlin_gemm(a: torch.Tensor,
|
def gptq_marlin_gemm(a: torch.Tensor,
|
||||||
c: Optional[torch.Tensor],
|
c: Optional[torch.Tensor],
|
||||||
b_q_weight: torch.Tensor,
|
b_q_weight: torch.Tensor,
|
||||||
|
b_bias: Optional[torch.Tensor],
|
||||||
b_scales: torch.Tensor,
|
b_scales: torch.Tensor,
|
||||||
global_scale: Optional[torch.Tensor],
|
global_scale: Optional[torch.Tensor],
|
||||||
b_zeros: Optional[torch.Tensor],
|
b_zeros: Optional[torch.Tensor],
|
||||||
@ -1062,7 +1032,7 @@ def gptq_marlin_gemm(a: torch.Tensor,
|
|||||||
use_atomic_add: bool = False,
|
use_atomic_add: bool = False,
|
||||||
use_fp32_reduce: bool = False,
|
use_fp32_reduce: bool = False,
|
||||||
is_zp_float: bool = False) -> torch.Tensor:
|
is_zp_float: bool = False) -> torch.Tensor:
|
||||||
return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales,
|
return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_bias, b_scales,
|
||||||
global_scale, b_zeros, g_idx, perm,
|
global_scale, b_zeros, g_idx, perm,
|
||||||
workspace, b_q_type.id, size_m,
|
workspace, b_q_type.id, size_m,
|
||||||
size_n, size_k, is_k_full,
|
size_n, size_k, is_k_full,
|
||||||
@ -1540,7 +1510,9 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
|||||||
|
|
||||||
|
|
||||||
def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
|
def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
|
||||||
b_qweight: torch.Tensor, b_scales: torch.Tensor,
|
b_qweight: torch.Tensor,
|
||||||
|
b_bias: Optional[torch.Tensor],
|
||||||
|
b_scales: torch.Tensor,
|
||||||
global_scale: Optional[torch.Tensor],
|
global_scale: Optional[torch.Tensor],
|
||||||
b_qzeros: Optional[torch.Tensor],
|
b_qzeros: Optional[torch.Tensor],
|
||||||
g_idx: Optional[torch.Tensor],
|
g_idx: Optional[torch.Tensor],
|
||||||
@ -1556,11 +1528,11 @@ def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
|
|||||||
use_fp32_reduce: bool,
|
use_fp32_reduce: bool,
|
||||||
is_zp_float: bool) -> torch.Tensor:
|
is_zp_float: bool) -> torch.Tensor:
|
||||||
return torch.ops._moe_C.moe_wna16_marlin_gemm(
|
return torch.ops._moe_C.moe_wna16_marlin_gemm(
|
||||||
input, output, b_qweight, b_scales, global_scale, b_qzeros, g_idx,
|
input, output, b_qweight, b_bias, b_scales, global_scale, b_qzeros,
|
||||||
perm, workspace, sorted_token_ids, expert_ids, num_tokens_past_padded,
|
g_idx, perm, workspace, sorted_token_ids, expert_ids,
|
||||||
topk_weights, moe_block_size, top_k, mul_topk_weights, is_ep,
|
num_tokens_past_padded, topk_weights, moe_block_size, top_k,
|
||||||
b_q_type.id, size_m, size_n, size_k, is_k_full, use_atomic_add,
|
mul_topk_weights, is_ep, b_q_type.id, size_m, size_n, size_k,
|
||||||
use_fp32_reduce, is_zp_float)
|
is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float)
|
||||||
|
|
||||||
|
|
||||||
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
|
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
|
||||||
|
|||||||
@ -101,11 +101,6 @@ class AttentionBackend(ABC):
|
|||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def advance_step(self, model_input: "ModelRunnerInputBase",
|
|
||||||
sampled_token_ids: Optional[torch.Tensor],
|
|
||||||
block_size: int, num_seqs: int, num_queries: int) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def full_cls_name(cls) -> tuple[str, str]:
|
def full_cls_name(cls) -> tuple[str, str]:
|
||||||
return (cls.__module__, cls.__qualname__)
|
return (cls.__module__, cls.__qualname__)
|
||||||
|
|||||||
@ -35,8 +35,7 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
|||||||
flash_attn_with_kvcache)
|
flash_attn_with_kvcache)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||||
ModelInputForGPUWithSamplingMetadata)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -326,79 +325,6 @@ class DifferentialFlashAttentionMetadata(AttentionMetadata):
|
|||||||
cross_block_tables=self.cross_block_tables)
|
cross_block_tables=self.cross_block_tables)
|
||||||
return self._cached_decode_metadata
|
return self._cached_decode_metadata
|
||||||
|
|
||||||
def advance_step(self,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
sampled_token_ids: Optional[torch.Tensor],
|
|
||||||
block_size: int,
|
|
||||||
num_seqs: int,
|
|
||||||
num_queries: int,
|
|
||||||
turn_prefills_into_decodes: bool = False):
|
|
||||||
"""
|
|
||||||
Update metadata in-place to advance one decode step.
|
|
||||||
"""
|
|
||||||
# When using cudagraph, the num_seqs is padded to the next captured
|
|
||||||
# batch sized, but num_queries tracks the actual number of requests in
|
|
||||||
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
|
||||||
if num_seqs != num_queries:
|
|
||||||
assert num_seqs > num_queries
|
|
||||||
assert self.use_cuda_graph
|
|
||||||
|
|
||||||
if turn_prefills_into_decodes:
|
|
||||||
# When Multi-Step is enabled with Chunked-Prefill, prefills and
|
|
||||||
# decodes are scheduled together. In the first step, all the
|
|
||||||
# prefills turn into decodes. This update reflects that
|
|
||||||
# conversion.
|
|
||||||
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
|
||||||
self.num_decode_tokens += self.num_prefills
|
|
||||||
self.num_prefills = 0
|
|
||||||
self.num_prefill_tokens = 0
|
|
||||||
self.max_prefill_seq_len = 0
|
|
||||||
self.max_query_len = 1
|
|
||||||
|
|
||||||
self.slot_mapping = self.slot_mapping[:num_seqs]
|
|
||||||
else:
|
|
||||||
assert self.seq_lens is not None
|
|
||||||
assert self.max_decode_seq_len == max(self.seq_lens)
|
|
||||||
|
|
||||||
assert self.num_prefills == 0
|
|
||||||
assert self.num_prefill_tokens == 0
|
|
||||||
assert self.num_decode_tokens == num_seqs
|
|
||||||
assert self.slot_mapping.shape == (num_seqs, )
|
|
||||||
|
|
||||||
assert self.seq_lens is not None
|
|
||||||
assert len(self.seq_lens) == num_seqs
|
|
||||||
assert self.seq_lens_tensor is not None
|
|
||||||
assert self.seq_lens_tensor.shape == (num_seqs, )
|
|
||||||
assert self.max_query_len == 1
|
|
||||||
assert self.max_prefill_seq_len == 0
|
|
||||||
|
|
||||||
assert self.query_start_loc is not None
|
|
||||||
assert self.query_start_loc.shape == (num_queries + 1, )
|
|
||||||
assert self.seq_start_loc is not None
|
|
||||||
assert self.seq_start_loc.shape == (num_seqs + 1, )
|
|
||||||
|
|
||||||
assert self.context_lens_tensor is not None
|
|
||||||
assert self.context_lens_tensor.shape == (num_queries, )
|
|
||||||
|
|
||||||
assert self.block_tables is not None
|
|
||||||
assert self.block_tables.shape[0] == num_seqs
|
|
||||||
|
|
||||||
# Update query lengths. Note that we update only queries and not seqs,
|
|
||||||
# since tensors may be padded due to captured cuda graph batch size
|
|
||||||
for i in range(num_queries):
|
|
||||||
self.seq_lens[i] += 1
|
|
||||||
self.max_decode_seq_len = max(self.seq_lens)
|
|
||||||
|
|
||||||
ops.advance_step_flashattn(num_seqs=num_seqs,
|
|
||||||
num_queries=num_queries,
|
|
||||||
block_size=block_size,
|
|
||||||
input_tokens=model_input.input_tokens,
|
|
||||||
sampled_token_ids=sampled_token_ids,
|
|
||||||
input_positions=model_input.input_positions,
|
|
||||||
seq_lens=self.seq_lens_tensor,
|
|
||||||
slot_mapping=self.slot_mapping,
|
|
||||||
block_tables=self.block_tables)
|
|
||||||
|
|
||||||
|
|
||||||
class DifferentialFlashAttentionMetadataBuilder(
|
class DifferentialFlashAttentionMetadataBuilder(
|
||||||
AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]):
|
AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]):
|
||||||
|
|||||||
@ -32,8 +32,7 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
|||||||
flash_attn_with_kvcache)
|
flash_attn_with_kvcache)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||||
ModelInputForGPUWithSamplingMetadata)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -309,79 +308,6 @@ class FlashAttentionMetadata(AttentionMetadata):
|
|||||||
cross_block_tables=self.cross_block_tables)
|
cross_block_tables=self.cross_block_tables)
|
||||||
return self._cached_decode_metadata
|
return self._cached_decode_metadata
|
||||||
|
|
||||||
def advance_step(self,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
sampled_token_ids: Optional[torch.Tensor],
|
|
||||||
block_size: int,
|
|
||||||
num_seqs: int,
|
|
||||||
num_queries: int,
|
|
||||||
turn_prefills_into_decodes: bool = False):
|
|
||||||
"""
|
|
||||||
Update metadata in-place to advance one decode step.
|
|
||||||
"""
|
|
||||||
# When using cudagraph, the num_seqs is padded to the next captured
|
|
||||||
# batch sized, but num_queries tracks the actual number of requests in
|
|
||||||
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
|
||||||
if num_seqs != num_queries:
|
|
||||||
assert num_seqs > num_queries
|
|
||||||
assert self.use_cuda_graph
|
|
||||||
|
|
||||||
if turn_prefills_into_decodes:
|
|
||||||
# When Multi-Step is enabled with Chunked-Prefill, prefills and
|
|
||||||
# decodes are scheduled together. In the first step, all the
|
|
||||||
# prefills turn into decodes. This update reflects that
|
|
||||||
# conversion.
|
|
||||||
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
|
||||||
self.num_decode_tokens += self.num_prefills
|
|
||||||
self.num_prefills = 0
|
|
||||||
self.num_prefill_tokens = 0
|
|
||||||
self.max_prefill_seq_len = 0
|
|
||||||
self.max_query_len = 1
|
|
||||||
|
|
||||||
self.slot_mapping = self.slot_mapping[:num_seqs]
|
|
||||||
else:
|
|
||||||
assert self.seq_lens is not None
|
|
||||||
assert self.max_decode_seq_len == max(self.seq_lens)
|
|
||||||
|
|
||||||
assert self.num_prefills == 0
|
|
||||||
assert self.num_prefill_tokens == 0
|
|
||||||
assert self.num_decode_tokens == num_seqs
|
|
||||||
assert self.slot_mapping.shape == (num_seqs, )
|
|
||||||
|
|
||||||
assert self.seq_lens is not None
|
|
||||||
assert len(self.seq_lens) == num_seqs
|
|
||||||
assert self.seq_lens_tensor is not None
|
|
||||||
assert self.seq_lens_tensor.shape == (num_seqs, )
|
|
||||||
assert self.max_query_len == 1
|
|
||||||
assert self.max_prefill_seq_len == 0
|
|
||||||
|
|
||||||
assert self.query_start_loc is not None
|
|
||||||
assert self.query_start_loc.shape == (num_queries + 1, )
|
|
||||||
assert self.seq_start_loc is not None
|
|
||||||
assert self.seq_start_loc.shape == (num_seqs + 1, )
|
|
||||||
|
|
||||||
assert self.context_lens_tensor is not None
|
|
||||||
assert self.context_lens_tensor.shape == (num_queries, )
|
|
||||||
|
|
||||||
assert self.block_tables is not None
|
|
||||||
assert self.block_tables.shape[0] == num_seqs
|
|
||||||
|
|
||||||
# Update query lengths. Note that we update only queries and not seqs,
|
|
||||||
# since tensors may be padded due to captured cuda graph batch size
|
|
||||||
for i in range(num_queries):
|
|
||||||
self.seq_lens[i] += 1
|
|
||||||
self.max_decode_seq_len = max(self.seq_lens)
|
|
||||||
|
|
||||||
ops.advance_step_flashattn(num_seqs=num_seqs,
|
|
||||||
num_queries=num_queries,
|
|
||||||
block_size=block_size,
|
|
||||||
input_tokens=model_input.input_tokens,
|
|
||||||
sampled_token_ids=sampled_token_ids,
|
|
||||||
input_positions=model_input.input_positions,
|
|
||||||
seq_lens=self.seq_lens_tensor,
|
|
||||||
slot_mapping=self.slot_mapping,
|
|
||||||
block_tables=self.block_tables)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionMetadataBuilder(
|
class FlashAttentionMetadataBuilder(
|
||||||
AttentionMetadataBuilder[FlashAttentionMetadata]):
|
AttentionMetadataBuilder[FlashAttentionMetadata]):
|
||||||
|
|||||||
@ -51,8 +51,7 @@ from vllm.utils.flashinfer import use_trtllm_attention
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||||
ModelInputForGPUWithSamplingMetadata)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashInferBackend(AttentionBackend):
|
class FlashInferBackend(AttentionBackend):
|
||||||
@ -428,7 +427,7 @@ class FlashInferMetadata(AttentionMetadata):
|
|||||||
query_start_loc: Optional[torch.Tensor] = None
|
query_start_loc: Optional[torch.Tensor] = None
|
||||||
block_tables: Optional[torch.Tensor] = None
|
block_tables: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# used for GPU in-place advance_step
|
# used for GPU operations
|
||||||
seq_lens_tensor: Optional[torch.Tensor] = None
|
seq_lens_tensor: Optional[torch.Tensor] = None
|
||||||
block_table_bound: Optional[torch.Tensor] = None
|
block_table_bound: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
@ -587,66 +586,6 @@ class FlashInferMetadata(AttentionMetadata):
|
|||||||
return None
|
return None
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def advance_step(self,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
sampled_token_ids: Optional[torch.Tensor],
|
|
||||||
block_size: int,
|
|
||||||
num_seqs: int,
|
|
||||||
num_queries: int,
|
|
||||||
turn_prefills_into_decodes: bool = False):
|
|
||||||
"""
|
|
||||||
Update metadata in-place to advance one decode step.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if turn_prefills_into_decodes:
|
|
||||||
# When Multi-Step is enabled with Chunked-Prefill, prefills and
|
|
||||||
# decodes are scheduled together. In the first step, all the
|
|
||||||
# prefills turn into decodes. This update reflects that
|
|
||||||
# conversion.
|
|
||||||
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
|
||||||
# Flashinfer doesn't support speculative decoding + chunked-prefill
|
|
||||||
# + multi-step scheduling yet.
|
|
||||||
assert self.decode_query_len == 1
|
|
||||||
self.num_decode_tokens += self.num_prefills
|
|
||||||
self.num_prefills = 0
|
|
||||||
self.num_prefill_tokens = 0
|
|
||||||
self.max_prefill_seq_len = 0
|
|
||||||
self.max_query_len = 1
|
|
||||||
|
|
||||||
self.slot_mapping = self.slot_mapping[:num_seqs]
|
|
||||||
else:
|
|
||||||
assert self.seq_lens_tensor is not None
|
|
||||||
|
|
||||||
assert num_seqs > 0
|
|
||||||
assert num_queries > 0
|
|
||||||
assert model_input.attn_metadata is not None
|
|
||||||
assert sampled_token_ids is not None
|
|
||||||
|
|
||||||
# When using cudagraph, the num_seqs is padded to the next captured
|
|
||||||
# batch sized, but num_queries tracks the actual number of requests in
|
|
||||||
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
|
||||||
if num_seqs != num_queries:
|
|
||||||
assert num_seqs > num_queries
|
|
||||||
assert self.use_cuda_graph
|
|
||||||
|
|
||||||
model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()
|
|
||||||
|
|
||||||
# Update GPU tensors
|
|
||||||
ops.advance_step_flashinfer(
|
|
||||||
num_seqs=num_seqs,
|
|
||||||
num_queries=num_queries,
|
|
||||||
block_size=block_size,
|
|
||||||
input_tokens=model_input.input_tokens,
|
|
||||||
sampled_token_ids=model_input.input_tokens,
|
|
||||||
input_positions=model_input.input_positions,
|
|
||||||
seq_lens=self.seq_lens_tensor,
|
|
||||||
slot_mapping=self.slot_mapping,
|
|
||||||
block_tables=self.block_tables,
|
|
||||||
paged_kv_indices=self.paged_kv_indices,
|
|
||||||
paged_kv_indptr=self.paged_kv_indptr,
|
|
||||||
paged_kv_last_page_len=self.paged_kv_last_page_len,
|
|
||||||
block_table_bound=self.block_table_bound)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
|
from typing import List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -18,9 +18,6 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
|
|||||||
get_mla_metadata,
|
get_mla_metadata,
|
||||||
is_flashmla_supported)
|
is_flashmla_supported)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
|
||||||
|
|
||||||
|
|
||||||
class FlashMLABackend(MLACommonBackend):
|
class FlashMLABackend(MLACommonBackend):
|
||||||
|
|
||||||
@ -62,16 +59,6 @@ class FlashMLAMetadata(MLACommonMetadata):
|
|||||||
self.decode_num_splits
|
self.decode_num_splits
|
||||||
return decode_metadata
|
return decode_metadata
|
||||||
|
|
||||||
def advance_step(self,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
sampled_token_ids: Optional[torch.Tensor],
|
|
||||||
block_size: int,
|
|
||||||
num_seqs: int,
|
|
||||||
num_queries: int,
|
|
||||||
turn_prefills_into_decodes: bool = False):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"advance_step is not implemented for FlashMLA")
|
|
||||||
|
|
||||||
|
|
||||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||||
|
|
||||||
|
|||||||
@ -234,8 +234,7 @@ except ImportError:
|
|||||||
flash_attn_varlen_func = None
|
flash_attn_varlen_func = None
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||||
ModelInputForGPUWithSamplingMetadata)
|
|
||||||
|
|
||||||
is_hip = current_platform.is_rocm()
|
is_hip = current_platform.is_rocm()
|
||||||
|
|
||||||
@ -631,90 +630,6 @@ class MLACommonMetadata(AttentionMetadata):
|
|||||||
is_profile_run=self.is_profile_run)
|
is_profile_run=self.is_profile_run)
|
||||||
return self._cached_decode_metadata
|
return self._cached_decode_metadata
|
||||||
|
|
||||||
def advance_step(self,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
sampled_token_ids: Optional[torch.Tensor],
|
|
||||||
block_size: int,
|
|
||||||
num_seqs: int,
|
|
||||||
num_queries: int,
|
|
||||||
turn_prefills_into_decodes: bool = False):
|
|
||||||
"""
|
|
||||||
Update metadata in-place to advance one decode step.
|
|
||||||
"""
|
|
||||||
# When using cudagraph, the num_seqs is padded to the next captured
|
|
||||||
# batch sized, but num_queries tracks the actual number of requests in
|
|
||||||
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
|
||||||
if num_seqs != num_queries:
|
|
||||||
assert num_seqs > num_queries
|
|
||||||
|
|
||||||
if turn_prefills_into_decodes:
|
|
||||||
# When Multi-Step is enabled with Chunked-Prefill, prefills and
|
|
||||||
# decodes are scheduled together. In the first step, all the
|
|
||||||
# prefills turn into decodes. This update reflects that
|
|
||||||
# conversion.
|
|
||||||
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
|
||||||
self.num_decode_tokens += self.num_prefills
|
|
||||||
self.num_prefills = 0
|
|
||||||
self.num_prefill_tokens = 0
|
|
||||||
self.max_prefill_seq_len = 0
|
|
||||||
self.max_query_len = 1
|
|
||||||
|
|
||||||
self.slot_mapping = self.slot_mapping[:num_seqs]
|
|
||||||
else:
|
|
||||||
assert self.seq_lens is not None
|
|
||||||
assert self.max_decode_seq_len == max(self.seq_lens)
|
|
||||||
|
|
||||||
assert self.num_prefills == 0
|
|
||||||
assert self.num_prefill_tokens == 0
|
|
||||||
assert self.num_decode_tokens == num_seqs
|
|
||||||
assert self.slot_mapping.shape == (num_seqs, )
|
|
||||||
|
|
||||||
assert self.seq_lens is not None
|
|
||||||
assert len(self.seq_lens) == num_seqs
|
|
||||||
assert self.seq_lens_tensor is not None
|
|
||||||
assert self.seq_lens_tensor.shape == (num_seqs, )
|
|
||||||
assert self.max_query_len == 1
|
|
||||||
assert self.max_prefill_seq_len == 0
|
|
||||||
|
|
||||||
assert self.query_start_loc is not None
|
|
||||||
assert self.query_start_loc.shape == (num_queries + 1, )
|
|
||||||
assert self.seq_start_loc is not None
|
|
||||||
assert self.seq_start_loc.shape == (num_seqs + 1, )
|
|
||||||
|
|
||||||
assert self.context_lens_tensor is not None
|
|
||||||
assert self.context_lens_tensor.shape == (num_queries, )
|
|
||||||
|
|
||||||
assert self.block_tables is not None
|
|
||||||
assert self.block_tables.shape[0] == num_seqs
|
|
||||||
|
|
||||||
# Update query lengths. Note that we update only queries and not seqs,
|
|
||||||
# since tensors may be padded due to captured cuda graph batch size
|
|
||||||
for i in range(num_queries):
|
|
||||||
self.seq_lens[i] += 1
|
|
||||||
self.max_decode_seq_len = max(self.seq_lens)
|
|
||||||
|
|
||||||
self._ops_advance_step(num_seqs=num_seqs,
|
|
||||||
num_queries=num_queries,
|
|
||||||
block_size=block_size,
|
|
||||||
input_tokens=model_input.input_tokens,
|
|
||||||
sampled_token_ids=sampled_token_ids,
|
|
||||||
input_positions=model_input.input_positions)
|
|
||||||
|
|
||||||
def _ops_advance_step(self, num_seqs: int, num_queries: int,
|
|
||||||
block_size: int, input_tokens: torch.Tensor,
|
|
||||||
sampled_token_ids: torch.Tensor,
|
|
||||||
input_positions: torch.Tensor) -> None:
|
|
||||||
# here we use advance_step_flashinfo to update the paged_kv_* tensors
|
|
||||||
ops.advance_step_flashattn(num_seqs=num_seqs,
|
|
||||||
num_queries=num_queries,
|
|
||||||
block_size=block_size,
|
|
||||||
input_tokens=input_tokens,
|
|
||||||
sampled_token_ids=sampled_token_ids,
|
|
||||||
input_positions=input_positions,
|
|
||||||
seq_lens=self.seq_lens_tensor,
|
|
||||||
slot_mapping=self.slot_mapping,
|
|
||||||
block_tables=self.block_tables)
|
|
||||||
|
|
||||||
|
|
||||||
class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
|
class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -15,8 +15,7 @@ from vllm.attention.backends.utils import CommonAttentionState
|
|||||||
from vllm.multimodal import MultiModalPlaceholderMap
|
from vllm.multimodal import MultiModalPlaceholderMap
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
from vllm.worker.model_runner import (ModelInputForGPUBuilder)
|
||||||
ModelInputForGPUWithSamplingMetadata)
|
|
||||||
from vllm.utils import async_tensor_h2d
|
from vllm.utils import async_tensor_h2d
|
||||||
|
|
||||||
# Placeholder attention backend for models like Mamba and pooling models that
|
# Placeholder attention backend for models like Mamba and pooling models that
|
||||||
@ -201,65 +200,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
|
|||||||
)
|
)
|
||||||
return self._cached_decode_metadata
|
return self._cached_decode_metadata
|
||||||
|
|
||||||
def advance_step(self,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
sampled_token_ids: Optional[torch.Tensor],
|
|
||||||
block_size: int,
|
|
||||||
num_seqs: int,
|
|
||||||
num_queries: int,
|
|
||||||
turn_prefills_into_decodes: bool = False):
|
|
||||||
"""
|
|
||||||
Update metadata in-place to advance one decode step.
|
|
||||||
"""
|
|
||||||
# When using cudagraph, the num_seqs is padded to the next captured
|
|
||||||
# batch sized, but num_queries tracks the actual number of requests in
|
|
||||||
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
|
||||||
if num_seqs != num_queries:
|
|
||||||
assert num_seqs > num_queries
|
|
||||||
assert self.use_cuda_graph
|
|
||||||
|
|
||||||
assert not turn_prefills_into_decodes, \
|
|
||||||
("Multi-Step + Chunked-Prefill is not supported for attention-free"
|
|
||||||
"models. turn_prefills_into_decodes is a "
|
|
||||||
"Multi-Step + Chunked-Prefill specific parameter.")
|
|
||||||
|
|
||||||
assert self.seq_lens is not None
|
|
||||||
assert self.max_decode_seq_len == max(self.seq_lens)
|
|
||||||
|
|
||||||
assert self.num_prefills == 0
|
|
||||||
assert self.num_prefill_tokens == 0
|
|
||||||
assert self.num_decode_tokens == num_seqs
|
|
||||||
|
|
||||||
assert self.seq_lens is not None
|
|
||||||
assert len(self.seq_lens) == num_seqs
|
|
||||||
assert self.seq_lens_tensor is not None
|
|
||||||
assert self.seq_lens_tensor.shape == (num_seqs, )
|
|
||||||
assert self.max_query_len == 1
|
|
||||||
assert self.max_prefill_seq_len == 0
|
|
||||||
|
|
||||||
assert self.query_start_loc is not None
|
|
||||||
assert self.query_start_loc.shape == (num_queries + 1, )
|
|
||||||
assert self.seq_start_loc is not None
|
|
||||||
assert self.seq_start_loc.shape == (num_seqs + 1, )
|
|
||||||
|
|
||||||
assert self.context_lens_tensor is not None
|
|
||||||
assert self.context_lens_tensor.shape == (num_queries, )
|
|
||||||
|
|
||||||
# Update query lengths. Note that we update only queries and not seqs,
|
|
||||||
# since tensors may be padded due to captured cuda graph batch size
|
|
||||||
for i in range(num_queries):
|
|
||||||
self.seq_lens[i] += 1
|
|
||||||
self.max_decode_seq_len = max(self.seq_lens)
|
|
||||||
|
|
||||||
# Update sequences, masking off entries greater than num_queries
|
|
||||||
device = self.seq_lens_tensor.device
|
|
||||||
mask = torch.arange(self.seq_lens_tensor.size(0),
|
|
||||||
device=device) < num_queries
|
|
||||||
self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype)
|
|
||||||
if sampled_token_ids is not None:
|
|
||||||
model_input.input_tokens.masked_scatter_(
|
|
||||||
mask, sampled_token_ids[:num_queries])
|
|
||||||
|
|
||||||
|
|
||||||
class PlaceholderAttentionMetadataBuilder(
|
class PlaceholderAttentionMetadataBuilder(
|
||||||
AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
|
AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
|
||||||
|
|||||||
@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Optional, Type, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm._custom_ops as ops
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.attention.backends.mla.common import (MLACommonBackend,
|
from vllm.attention.backends.mla.common import (MLACommonBackend,
|
||||||
MLACommonImpl,
|
MLACommonImpl,
|
||||||
@ -107,26 +106,6 @@ class AiterMLAMetadata(MLACommonMetadata):
|
|||||||
|
|
||||||
return self._cached_decode_metadata
|
return self._cached_decode_metadata
|
||||||
|
|
||||||
def _ops_advance_step(self, num_seqs: int, num_queries: int,
|
|
||||||
block_size: int, input_tokens: torch.Tensor,
|
|
||||||
sampled_token_ids: torch.Tensor,
|
|
||||||
input_positions: torch.Tensor) -> None:
|
|
||||||
|
|
||||||
ops.advance_step_flashinfer(
|
|
||||||
num_seqs=num_seqs,
|
|
||||||
num_queries=num_queries,
|
|
||||||
block_size=block_size,
|
|
||||||
input_tokens=input_tokens,
|
|
||||||
sampled_token_ids=sampled_token_ids,
|
|
||||||
input_positions=input_positions,
|
|
||||||
seq_lens=self.seq_lens_tensor,
|
|
||||||
slot_mapping=self.slot_mapping,
|
|
||||||
block_tables=self.block_tables,
|
|
||||||
paged_kv_indices=self.paged_kv_indices,
|
|
||||||
paged_kv_indptr=self.paged_kv_indptr,
|
|
||||||
paged_kv_last_page_lens=self.paged_kv_last_page_lens,
|
|
||||||
block_table_bound=self.block_table_bound)
|
|
||||||
|
|
||||||
|
|
||||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||||
BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]
|
BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
import itertools
|
import itertools
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
|
from typing import List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -23,9 +23,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
GroupShape)
|
GroupShape)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
_PARTITION_SIZE_ROCM = 256
|
_PARTITION_SIZE_ROCM = 256
|
||||||
|
|
||||||
@ -261,69 +258,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
self._cached_decode_metadata.query_start_loc = qs - qs[0]
|
self._cached_decode_metadata.query_start_loc = qs - qs[0]
|
||||||
return self._cached_decode_metadata
|
return self._cached_decode_metadata
|
||||||
|
|
||||||
def advance_step(self,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
sampled_token_ids: Optional[torch.Tensor],
|
|
||||||
block_size: int,
|
|
||||||
num_seqs: int,
|
|
||||||
num_queries: int,
|
|
||||||
turn_prefills_into_decodes: bool = False):
|
|
||||||
"""
|
|
||||||
Update metadata in-place to advance one decode step.
|
|
||||||
"""
|
|
||||||
|
|
||||||
assert not turn_prefills_into_decodes, \
|
|
||||||
("Chunked prefill is not supported with rocm_flash_attn yet."
|
|
||||||
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
|
|
||||||
"specific parameter.")
|
|
||||||
|
|
||||||
# When using cudagraph, the num_seqs is padded to the next captured
|
|
||||||
# batch sized, but num_queries tracks the actual number of requests in
|
|
||||||
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
|
||||||
if num_seqs != num_queries:
|
|
||||||
assert num_seqs > num_queries
|
|
||||||
assert self.use_cuda_graph
|
|
||||||
|
|
||||||
assert self.num_prefills == 0
|
|
||||||
assert self.num_prefill_tokens == 0
|
|
||||||
assert self.num_decode_tokens == num_seqs
|
|
||||||
assert self.slot_mapping.shape == (num_seqs, )
|
|
||||||
|
|
||||||
assert self.seq_lens is not None
|
|
||||||
assert len(self.seq_lens) == num_seqs
|
|
||||||
assert self.seq_lens_tensor is not None
|
|
||||||
assert self.seq_lens_tensor.shape == (num_seqs, )
|
|
||||||
assert self.max_query_len == 1
|
|
||||||
assert self.max_prefill_seq_len == 0
|
|
||||||
assert self.max_decode_seq_len == max(self.seq_lens)
|
|
||||||
|
|
||||||
assert self.query_start_loc is not None
|
|
||||||
assert self.query_start_loc.shape == (num_queries + 1, )
|
|
||||||
assert self.seq_start_loc is not None
|
|
||||||
assert self.seq_start_loc.shape == (num_seqs + 1, )
|
|
||||||
|
|
||||||
assert self.context_lens_tensor is not None
|
|
||||||
assert self.context_lens_tensor.shape == (num_queries, )
|
|
||||||
|
|
||||||
assert self.block_tables is not None
|
|
||||||
assert self.block_tables.shape[0] == num_seqs
|
|
||||||
|
|
||||||
# Update query lengths. Note that we update only queries and not seqs,
|
|
||||||
# since tensors may be padded due to captured cuda graph batch size
|
|
||||||
for i in range(num_queries):
|
|
||||||
self.seq_lens[i] += 1
|
|
||||||
self.max_decode_seq_len = max(self.seq_lens)
|
|
||||||
|
|
||||||
ops.advance_step_flashattn(num_seqs=num_seqs,
|
|
||||||
num_queries=num_queries,
|
|
||||||
block_size=block_size,
|
|
||||||
input_tokens=model_input.input_tokens,
|
|
||||||
sampled_token_ids=sampled_token_ids,
|
|
||||||
input_positions=model_input.input_positions,
|
|
||||||
seq_lens=self.seq_lens_tensor,
|
|
||||||
slot_mapping=self.slot_mapping,
|
|
||||||
block_tables=self.block_tables)
|
|
||||||
|
|
||||||
|
|
||||||
class ROCmFlashAttentionMetadataBuilder(
|
class ROCmFlashAttentionMetadataBuilder(
|
||||||
CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
|
CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
|
||||||
|
|||||||
@ -15,7 +15,7 @@ import torch.fx as fx
|
|||||||
from torch._dispatch.python import enable_python_dispatcher
|
from torch._dispatch.python import enable_python_dispatcher
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import CompilationConfig, VllmConfig
|
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
|
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
|
||||||
@ -277,9 +277,6 @@ def split_graph(graph: fx.GraphModule,
|
|||||||
return split_gm, outputs
|
return split_gm, outputs
|
||||||
|
|
||||||
|
|
||||||
# we share the global graph pool among all the backends
|
|
||||||
global_graph_pool = None
|
|
||||||
|
|
||||||
compilation_start_time = 0.0
|
compilation_start_time = 0.0
|
||||||
|
|
||||||
|
|
||||||
@ -339,14 +336,37 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
|||||||
graph_index=index,
|
graph_index=index,
|
||||||
num_graphs=len(self.compile_submod_names),
|
num_graphs=len(self.compile_submod_names),
|
||||||
runtime_shape=None)
|
runtime_shape=None)
|
||||||
|
# Lazy import here to avoid circular import
|
||||||
|
from .cuda_graph import CUDAGraphOptions
|
||||||
|
from .cuda_piecewise_backend import PiecewiseBackend
|
||||||
|
|
||||||
piecewise_backend = resolve_obj_by_qualname(
|
piecewise_backend = PiecewiseBackend(
|
||||||
current_platform.get_piecewise_backend_cls())
|
submod, self.vllm_config, index,
|
||||||
self.module.__dict__[target] = piecewise_backend(
|
|
||||||
submod, self.vllm_config, self.graph_pool, index,
|
|
||||||
len(self.compile_submod_names), sym_shape_indices,
|
len(self.compile_submod_names), sym_shape_indices,
|
||||||
compiled_graph_for_dynamic_shape, self.vllm_backend)
|
compiled_graph_for_dynamic_shape, self.vllm_backend)
|
||||||
|
|
||||||
|
if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
|
||||||
|
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
|
||||||
|
# class) as platform dependent.
|
||||||
|
static_graph_wrapper_class = resolve_obj_by_qualname(
|
||||||
|
current_platform.get_static_graph_wrapper_cls())
|
||||||
|
|
||||||
|
# Always assign PIECEWISE runtime mode to the
|
||||||
|
# CUDAGraphWrapper for piecewise_backend, to distinguish
|
||||||
|
# it from the FULL cudagraph runtime mode, no matter it
|
||||||
|
# is wrapped on a full or piecewise fx graph.
|
||||||
|
self.module.__dict__[target] = static_graph_wrapper_class(
|
||||||
|
runnable=piecewise_backend,
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||||
|
graph_pool=self.graph_pool,
|
||||||
|
cudagraph_options=CUDAGraphOptions(
|
||||||
|
debug_log_enable=piecewise_backend.is_first_graph,
|
||||||
|
gc_disable=not piecewise_backend.is_first_graph,
|
||||||
|
weak_ref_output=piecewise_backend.is_last_graph))
|
||||||
|
else:
|
||||||
|
self.module.__dict__[target] = piecewise_backend
|
||||||
|
|
||||||
compilation_counter.num_piecewise_capturable_graphs_seen += 1
|
compilation_counter.num_piecewise_capturable_graphs_seen += 1
|
||||||
|
|
||||||
return output
|
return output
|
||||||
@ -413,9 +433,7 @@ class VllmBackend:
|
|||||||
# them, e.g. backbone (default), eagle_head, etc.
|
# them, e.g. backbone (default), eagle_head, etc.
|
||||||
self.prefix = prefix or model_tag
|
self.prefix = prefix or model_tag
|
||||||
|
|
||||||
global global_graph_pool
|
global_graph_pool = current_platform.get_global_graph_pool()
|
||||||
if global_graph_pool is None:
|
|
||||||
global_graph_pool = current_platform.graph_pool_handle()
|
|
||||||
|
|
||||||
# TODO: in the future, if we want to use multiple
|
# TODO: in the future, if we want to use multiple
|
||||||
# streams, it might not be safe to share a global pool.
|
# streams, it might not be safe to share a global pool.
|
||||||
@ -585,7 +603,7 @@ class VllmBackend:
|
|||||||
|
|
||||||
self._called = True
|
self._called = True
|
||||||
|
|
||||||
if not self.compilation_config.use_cudagraph or \
|
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \
|
||||||
not self.compilation_config.cudagraph_copy_inputs:
|
not self.compilation_config.cudagraph_copy_inputs:
|
||||||
return self.split_gm
|
return self.split_gm
|
||||||
|
|
||||||
|
|||||||
@ -1,72 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
from typing import Any, Callable, Protocol
|
|
||||||
|
|
||||||
import torch.fx as fx
|
|
||||||
|
|
||||||
from vllm.compilation.backends import VllmBackend
|
|
||||||
from vllm.config import VllmConfig
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractPiecewiseBackend(Protocol):
|
|
||||||
"""
|
|
||||||
PiecewiseBackend interface that allows platforms to extend
|
|
||||||
piecewise static graph.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
|
|
||||||
graph_pool: Any, piecewise_compile_index: int,
|
|
||||||
total_piecewise_compiles: int, sym_shape_indices: list[int],
|
|
||||||
compiled_graph_for_general_shape: Callable,
|
|
||||||
vllm_backend: VllmBackend, **kwargs):
|
|
||||||
"""
|
|
||||||
Initializes the PiecewiseBackend class with compilation and
|
|
||||||
execution-related configurations.
|
|
||||||
|
|
||||||
This class handles piecewise compilation, graph capturing,
|
|
||||||
and dispatching for specific input shapes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
graph (fx.GraphModule): The graph represented in fx.
|
|
||||||
vllm_config (VllmConfig): Global configuration for vLLM.
|
|
||||||
graph_pool (Any):
|
|
||||||
Graph memory pool handle, e.g.,
|
|
||||||
`torch.cuda.graph_pool_handle()`.
|
|
||||||
piecewise_compile_index (int):
|
|
||||||
Index of the current piecewise subgraph.
|
|
||||||
total_piecewise_compiles (int):
|
|
||||||
Total number of piecewise-compiled graphs.
|
|
||||||
sym_shape_indices (list[int]):
|
|
||||||
Indices of symbolic shape.
|
|
||||||
compiled_graph_for_general_shape (Callable):
|
|
||||||
Callable that executes the graph compiled for general shapes.
|
|
||||||
vllm_backend (VllmBackend):
|
|
||||||
Backend compiler that manages compilation and graph runtime
|
|
||||||
for vLLM.
|
|
||||||
|
|
||||||
Keyword Args:
|
|
||||||
kwargs: Additional keyword arguments reserved for future
|
|
||||||
extensions or custom platforms.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def __call__(self, *args) -> Any:
|
|
||||||
"""Executes the compiled graph for given input args.
|
|
||||||
|
|
||||||
If this is the first invocation, executes the general compiled graph
|
|
||||||
and initiates the compilation process tracking. For subsequent calls,
|
|
||||||
dynamically dispatches execution to either a compiled graph or a static
|
|
||||||
graph based on the input shape.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
*args: Variable length input arguments to be passed into the
|
|
||||||
graph. The symbolic shape is expected to be in position
|
|
||||||
`sym_shape_indices[0]`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: Output of the executed graph. This can be from the general
|
|
||||||
compiled graph, a specialized compiled version for the given shape,
|
|
||||||
or a replayed static graph.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
54
vllm/compilation/base_static_graph.py
Normal file
54
vllm/compilation/base_static_graph.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from typing import Any, Callable, Protocol
|
||||||
|
|
||||||
|
from vllm.config import CUDAGraphMode, VllmConfig
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractStaticGraphWrapper(Protocol):
|
||||||
|
"""
|
||||||
|
StaticGraphWrapper interface that allows platforms to wrap a callable
|
||||||
|
to be captured as a static graph.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
|
||||||
|
runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs):
|
||||||
|
"""
|
||||||
|
Initializes the StaticGraphWrapper class with graph capturing and
|
||||||
|
execution-related configurations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runnable (Callable): The callable to be wrapped and captured.
|
||||||
|
vllm_config (VllmConfig): Global configuration for vLLM.
|
||||||
|
runtime_mode (CUDAGraphMode): The style of the static
|
||||||
|
graph runtime. See CUDAGraphMode in vllm/config.py.
|
||||||
|
Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
|
||||||
|
are used as concrete runtime mode for cudagraph dispatching.
|
||||||
|
graph_pool (Any):
|
||||||
|
Graph memory pool handle, e.g.,
|
||||||
|
`torch.cuda.graph_pool_handle()`.
|
||||||
|
Keyword Args:
|
||||||
|
kwargs: Additional keyword arguments for platform-specific
|
||||||
|
configurations.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs) -> Any:
|
||||||
|
"""
|
||||||
|
Executes the wrapped callable.
|
||||||
|
|
||||||
|
If the current runtime mode in the ForwardContext matches the runtime
|
||||||
|
mode of this instance, it replays the CUDAGraph or captures it using
|
||||||
|
the callable if it hasn't been captured yet. Otherwise, it calls the
|
||||||
|
original callable directly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: Variable length input arguments to be passed into the
|
||||||
|
callable.
|
||||||
|
**kwargs: Keyword arguments to be passed into the callable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: Output of the executed callable.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
193
vllm/compilation/cuda_graph.py
Normal file
193
vllm/compilation/cuda_graph.py
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from contextlib import ExitStack
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.compilation.counter import compilation_counter
|
||||||
|
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
|
||||||
|
from vllm.config import CUDAGraphMode, VllmConfig
|
||||||
|
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import weak_ref_tensors
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class CUDAGraphEntry:
|
||||||
|
batch_descriptor: BatchDescriptor
|
||||||
|
cudagraph: Optional[torch.cuda.CUDAGraph] = None
|
||||||
|
output: Optional[Any] = None
|
||||||
|
|
||||||
|
# for cudagraph debugging, track the input addresses
|
||||||
|
# during capture, and check if they are the same during replay
|
||||||
|
input_addresses: Optional[list[int]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class CUDAGraphOptions:
|
||||||
|
debug_log_enable: bool = True
|
||||||
|
gc_disable: bool = False
|
||||||
|
weak_ref_output: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class CUDAGraphWrapper:
|
||||||
|
"""Wraps a runnable to add CUDA graph capturing and replaying ability. And
|
||||||
|
provide attribute access to the underlying `runnable` via `__getattr__`.
|
||||||
|
|
||||||
|
The workflow of this wrapper in the cudagraph dispatching is as follows:
|
||||||
|
1. At initialization, a runtime mode is assigned to the wrapper (FULL or
|
||||||
|
PIECEWISE).
|
||||||
|
2. At runtime, the wrapper receives a runtime_mode and a
|
||||||
|
batch_descriptor(key) from the forward context and blindly trust them
|
||||||
|
for cudagraph dispatching.
|
||||||
|
3. If runtime_mode is NONE or runtime_mode does not match the mode of the
|
||||||
|
wrapper, just call the runnable directly.
|
||||||
|
4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
|
||||||
|
the wrapper will perform cudagraph capture(if key does not exist, create
|
||||||
|
a new entry and cache it) or replay (if key exists in the cache).
|
||||||
|
|
||||||
|
Note: CUDAGraphWrapper does not store persistent buffers or copy any
|
||||||
|
runtime inputs into that buffers for replay. We assume implementing them
|
||||||
|
is done outside of the wrapper. That is because we do not make any
|
||||||
|
assumption on the dynamic shape (batch size) of the runtime inputs, as a
|
||||||
|
trade-off for staying orthogonal to compilation logic. Nevertheless,
|
||||||
|
tracing and checking the input addresses to be consistent during replay is
|
||||||
|
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
runnable: Callable,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
runtime_mode: CUDAGraphMode,
|
||||||
|
graph_pool: Any = None,
|
||||||
|
cudagraph_options: Optional[CUDAGraphOptions] = None):
|
||||||
|
self.runnable = runnable
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.graph_pool = graph_pool
|
||||||
|
self.runtime_mode = runtime_mode
|
||||||
|
self.compilation_config = vllm_config.compilation_config
|
||||||
|
|
||||||
|
self.first_run_finished = False
|
||||||
|
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||||
|
|
||||||
|
# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
|
||||||
|
# need to initialize a CUDAGraphWrapper.
|
||||||
|
assert self.runtime_mode != CUDAGraphMode.NONE
|
||||||
|
if self.graph_pool is None:
|
||||||
|
self.graph_pool = current_platform.get_global_graph_pool()
|
||||||
|
|
||||||
|
if cudagraph_options is None:
|
||||||
|
cudagraph_options = CUDAGraphOptions()
|
||||||
|
self.cudagraph_options = cudagraph_options
|
||||||
|
# the entries for different batch descriptors that we need to capture
|
||||||
|
# cudagraphs for.
|
||||||
|
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry]\
|
||||||
|
= {}
|
||||||
|
|
||||||
|
def __getattr__(self, key: str):
|
||||||
|
# allow accessing the attributes of the runnable.
|
||||||
|
if hasattr(self.runnable, key):
|
||||||
|
return getattr(self.runnable, key)
|
||||||
|
raise AttributeError(f"Attribute {key} not exists in the runnable of "
|
||||||
|
f"cudagraph wrapper: {self.runnable}")
|
||||||
|
|
||||||
|
def unwrap(self) -> Callable:
|
||||||
|
# in case we need to access the original runnable.
|
||||||
|
return self.runnable
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
batch_descriptor = forward_context.batch_descriptor
|
||||||
|
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||||
|
|
||||||
|
if cudagraph_runtime_mode == CUDAGraphMode.NONE or \
|
||||||
|
cudagraph_runtime_mode != self.runtime_mode:
|
||||||
|
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
|
||||||
|
# running without cudagraphs.
|
||||||
|
# We do not trigger capture/replay if the runtime mode is not
|
||||||
|
# matches. This enables properly dispatching to the correct
|
||||||
|
# CUDAGraphWrapper when nesting multiple instances with different
|
||||||
|
# runtime modes.
|
||||||
|
return self.runnable(*args, **kwargs)
|
||||||
|
|
||||||
|
if batch_descriptor not in self.concrete_cudagraph_entries:
|
||||||
|
# create a new entry for this batch descriptor
|
||||||
|
self.concrete_cudagraph_entries[batch_descriptor] = \
|
||||||
|
CUDAGraphEntry(batch_descriptor=batch_descriptor)
|
||||||
|
|
||||||
|
entry = self.concrete_cudagraph_entries[batch_descriptor]
|
||||||
|
|
||||||
|
if entry.cudagraph is None:
|
||||||
|
if self.cudagraph_options.debug_log_enable:
|
||||||
|
# Since we capture cudagraph for many different shapes and
|
||||||
|
# capturing is fast, we don't need to log it for every
|
||||||
|
# shape. E.g. we only log it for the first subgraph in
|
||||||
|
# piecewise mode.
|
||||||
|
logger.debug("Capturing a cudagraph on (%s,%s)",
|
||||||
|
self.runtime_mode.name, entry.batch_descriptor)
|
||||||
|
# validate that cudagraph capturing is legal at this point.
|
||||||
|
validate_cudagraph_capturing_enabled()
|
||||||
|
|
||||||
|
input_addresses = [
|
||||||
|
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||||
|
]
|
||||||
|
entry.input_addresses = input_addresses
|
||||||
|
cudagraph = torch.cuda.CUDAGraph()
|
||||||
|
|
||||||
|
with ExitStack() as stack:
|
||||||
|
if self.cudagraph_options.gc_disable:
|
||||||
|
# during every model forward for piecewise cudagraph
|
||||||
|
# mode, we will capture many pieces of cudagraphs
|
||||||
|
# (roughly one per layer). running gc again and again
|
||||||
|
# across layers will make the cudagraph capture very slow.
|
||||||
|
# therefore, we only run gc for the first graph,
|
||||||
|
# and disable gc for the rest of the graphs.
|
||||||
|
stack.enter_context(patch("gc.collect", lambda: None))
|
||||||
|
stack.enter_context(
|
||||||
|
patch("torch.cuda.empty_cache", lambda: None))
|
||||||
|
|
||||||
|
# mind-exploding: carefully manage the reference and memory.
|
||||||
|
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
||||||
|
# `output` is managed by pytorch's cudagraph pool
|
||||||
|
output = self.runnable(*args, **kwargs)
|
||||||
|
if self.cudagraph_options.weak_ref_output:
|
||||||
|
# by converting it to weak ref,
|
||||||
|
# the original `output` will immediately be released
|
||||||
|
# to save memory. It is only safe to do this for
|
||||||
|
# the last graph in piecewise cuadgraph mode, because
|
||||||
|
# the output of the last graph will not be used by
|
||||||
|
# any other cuda graph.
|
||||||
|
output = weak_ref_tensors(output)
|
||||||
|
|
||||||
|
# here we always use weak ref for the output
|
||||||
|
# to save memory
|
||||||
|
entry.output = weak_ref_tensors(output)
|
||||||
|
entry.cudagraph = cudagraph
|
||||||
|
|
||||||
|
compilation_counter.num_cudagraph_captured += 1
|
||||||
|
|
||||||
|
# important: we need to return the output, rather than
|
||||||
|
# the weak ref of the output, so that pytorch can correctly
|
||||||
|
# manage the memory during cuda graph capture
|
||||||
|
return output
|
||||||
|
|
||||||
|
if self.is_debugging_mode:
|
||||||
|
# check if the input addresses are the same
|
||||||
|
new_input_addresses = [
|
||||||
|
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||||
|
]
|
||||||
|
assert new_input_addresses == entry.input_addresses, (
|
||||||
|
f"Input addresses for cudagraphs are different "
|
||||||
|
f"during replay. Expected {entry.input_addresses}, "
|
||||||
|
f"got {new_input_addresses}")
|
||||||
|
|
||||||
|
entry.cudagraph.replay()
|
||||||
|
return entry.output
|
||||||
@ -2,21 +2,15 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from contextlib import ExitStack
|
from typing import Any, Callable
|
||||||
from typing import Any, Callable, Optional
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.fx as fx
|
import torch.fx as fx
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.compilation.backends import VllmBackend
|
from vllm.compilation.backends import VllmBackend
|
||||||
from vllm.compilation.counter import compilation_counter
|
|
||||||
from vllm.compilation.monitor import end_monitoring_torch_compile
|
from vllm.compilation.monitor import end_monitoring_torch_compile
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.forward_context import get_forward_context
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import weak_ref_tensors
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -24,44 +18,29 @@ logger = init_logger(__name__)
|
|||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ConcreteSizeEntry:
|
class ConcreteSizeEntry:
|
||||||
runtime_shape: int
|
runtime_shape: int
|
||||||
need_to_compile: bool # the size is in compile_sizes
|
|
||||||
use_cudagraph: bool # the size is in cudagraph_capture_sizes
|
|
||||||
|
|
||||||
compiled: bool = False
|
compiled: bool = False
|
||||||
runnable: Callable = None # type: ignore
|
runnable: Callable = None # type: ignore
|
||||||
num_finished_warmup: int = 0
|
|
||||||
cudagraph: Optional[torch.cuda.CUDAGraph] = None
|
|
||||||
output: Optional[Any] = None
|
|
||||||
|
|
||||||
# for cudagraph debugging, track the input addresses
|
|
||||||
# during capture, and check if they are the same during replay
|
|
||||||
input_addresses: Optional[list[int]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CUDAPiecewiseBackend:
|
class PiecewiseBackend:
|
||||||
|
|
||||||
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
|
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
|
||||||
graph_pool: Any, piecewise_compile_index: int,
|
piecewise_compile_index: int, total_piecewise_compiles: int,
|
||||||
total_piecewise_compiles: int, sym_shape_indices: list[int],
|
sym_shape_indices: list[int],
|
||||||
compiled_graph_for_general_shape: Callable,
|
compiled_graph_for_general_shape: Callable,
|
||||||
vllm_backend: VllmBackend):
|
vllm_backend: VllmBackend):
|
||||||
"""
|
"""
|
||||||
The backend for piecewise compilation.
|
The backend for piecewise compilation.
|
||||||
It mainly handles the compilation and cudagraph capturing.
|
It mainly handles the compilation of static shapes and
|
||||||
|
dispatching based on runtime shape.
|
||||||
|
|
||||||
We will compile `self.graph` once for the general shape,
|
We will compile `self.graph` once for the general shape,
|
||||||
and then compile for different shapes specified in
|
and then compile for different shapes specified in
|
||||||
`compilation_config.compile_sizes`.
|
`compilation_config.compile_sizes`.
|
||||||
|
|
||||||
Independently, we will capture cudagraph for different shapes.
|
|
||||||
|
|
||||||
If a shape needs both compilation and cudagraph, we will
|
|
||||||
compile it first, and then capture cudagraph.
|
|
||||||
"""
|
"""
|
||||||
self.graph = graph
|
self.graph = graph
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
self.compilation_config = vllm_config.compilation_config
|
self.compilation_config = vllm_config.compilation_config
|
||||||
self.graph_pool = graph_pool
|
|
||||||
self.piecewise_compile_index = piecewise_compile_index
|
self.piecewise_compile_index = piecewise_compile_index
|
||||||
self.total_piecewise_compiles = total_piecewise_compiles
|
self.total_piecewise_compiles = total_piecewise_compiles
|
||||||
self.vllm_backend = vllm_backend
|
self.vllm_backend = vllm_backend
|
||||||
@ -70,11 +49,10 @@ class CUDAPiecewiseBackend:
|
|||||||
self.is_last_graph = (
|
self.is_last_graph = (
|
||||||
piecewise_compile_index == total_piecewise_compiles - 1)
|
piecewise_compile_index == total_piecewise_compiles - 1)
|
||||||
|
|
||||||
|
self.is_full_graph = total_piecewise_compiles == 1
|
||||||
|
|
||||||
self.compile_sizes: set[int] = set(
|
self.compile_sizes: set[int] = set(
|
||||||
self.compilation_config.compile_sizes)
|
self.compilation_config.compile_sizes)
|
||||||
self.cudagraph_capture_sizes: set[int] = set(
|
|
||||||
self.compilation_config.cudagraph_capture_sizes
|
|
||||||
) if self.compilation_config.use_cudagraph else set()
|
|
||||||
|
|
||||||
self.first_run_finished = False
|
self.first_run_finished = False
|
||||||
|
|
||||||
@ -84,18 +62,18 @@ class CUDAPiecewiseBackend:
|
|||||||
|
|
||||||
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||||
|
|
||||||
# the entries for different shapes that we need to either
|
# the entries for different shapes that we need to compile
|
||||||
# compile or capture cudagraph
|
|
||||||
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
|
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
|
||||||
|
|
||||||
# to_be_compiled_sizes tracks the remaining sizes to compile,
|
# to_be_compiled_sizes tracks the remaining sizes to compile,
|
||||||
# and updates during the compilation process, so we need to copy it
|
# and updates during the compilation process, so we need to copy it
|
||||||
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
|
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
|
||||||
for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
|
|
||||||
|
# We only keep compilation management inside this class directly.
|
||||||
|
for shape in self.compile_sizes:
|
||||||
self.concrete_size_entries[shape] = ConcreteSizeEntry(
|
self.concrete_size_entries[shape] = ConcreteSizeEntry(
|
||||||
runtime_shape=shape,
|
runtime_shape=shape,
|
||||||
need_to_compile=shape in self.compile_sizes,
|
runnable=self.compiled_graph_for_general_shape,
|
||||||
use_cudagraph=shape in self.cudagraph_capture_sizes,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_for_ending_compilation(self):
|
def check_for_ending_compilation(self):
|
||||||
@ -112,16 +90,14 @@ class CUDAPiecewiseBackend:
|
|||||||
return self.compiled_graph_for_general_shape(*args)
|
return self.compiled_graph_for_general_shape(*args)
|
||||||
|
|
||||||
runtime_shape = args[self.sym_shape_indices[0]]
|
runtime_shape = args[self.sym_shape_indices[0]]
|
||||||
|
|
||||||
if runtime_shape not in self.concrete_size_entries:
|
if runtime_shape not in self.concrete_size_entries:
|
||||||
# we don't need to do anything for this shape
|
# we don't need to do anything for this shape
|
||||||
return self.compiled_graph_for_general_shape(*args)
|
return self.compiled_graph_for_general_shape(*args)
|
||||||
|
|
||||||
entry = self.concrete_size_entries[runtime_shape]
|
entry = self.concrete_size_entries[runtime_shape]
|
||||||
|
|
||||||
if entry.runnable is None:
|
if not entry.compiled:
|
||||||
entry.runnable = self.compiled_graph_for_general_shape
|
|
||||||
|
|
||||||
if entry.need_to_compile and not entry.compiled:
|
|
||||||
entry.compiled = True
|
entry.compiled = True
|
||||||
self.to_be_compiled_sizes.remove(runtime_shape)
|
self.to_be_compiled_sizes.remove(runtime_shape)
|
||||||
# args are real arguments
|
# args are real arguments
|
||||||
@ -138,81 +114,4 @@ class CUDAPiecewiseBackend:
|
|||||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||||
self.check_for_ending_compilation()
|
self.check_for_ending_compilation()
|
||||||
|
|
||||||
# Skip CUDA graphs if this entry doesn't use them OR
|
return entry.runnable(*args)
|
||||||
# if we're supposed to skip them globally
|
|
||||||
skip_cuda_graphs = get_forward_context().skip_cuda_graphs
|
|
||||||
if not entry.use_cudagraph or skip_cuda_graphs:
|
|
||||||
return entry.runnable(*args)
|
|
||||||
|
|
||||||
if entry.cudagraph is None:
|
|
||||||
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
|
|
||||||
entry.num_finished_warmup += 1
|
|
||||||
if self.is_first_graph:
|
|
||||||
logger.debug(
|
|
||||||
"Warming up %s/%s for shape %s",
|
|
||||||
entry.num_finished_warmup,
|
|
||||||
self.compilation_config.cudagraph_num_of_warmups,
|
|
||||||
runtime_shape)
|
|
||||||
return entry.runnable(*args)
|
|
||||||
|
|
||||||
if self.is_first_graph:
|
|
||||||
# Since we capture cudagraph for many different shapes and
|
|
||||||
# capturing is fast, we don't need to log it for every shape.
|
|
||||||
# We only log it in the debug mode.
|
|
||||||
logger.debug("Capturing a cudagraph for shape %s",
|
|
||||||
runtime_shape)
|
|
||||||
|
|
||||||
input_addresses = [
|
|
||||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
|
||||||
]
|
|
||||||
entry.input_addresses = input_addresses
|
|
||||||
cudagraph = torch.cuda.CUDAGraph()
|
|
||||||
|
|
||||||
with ExitStack() as stack:
|
|
||||||
if not self.is_first_graph:
|
|
||||||
# during every model forward, we will capture
|
|
||||||
# many pieces of cudagraphs (roughly one per layer).
|
|
||||||
# running gc again and again across layers will
|
|
||||||
# make the cudagraph capture very slow.
|
|
||||||
# therefore, we only run gc for the first graph,
|
|
||||||
# and disable gc for the rest of the graphs.
|
|
||||||
stack.enter_context(patch("gc.collect", lambda: None))
|
|
||||||
stack.enter_context(
|
|
||||||
patch("torch.cuda.empty_cache", lambda: None))
|
|
||||||
|
|
||||||
# mind-exploding: carefully manage the reference and memory.
|
|
||||||
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
|
||||||
# `output` is managed by pytorch's cudagraph pool
|
|
||||||
output = entry.runnable(*args)
|
|
||||||
if self.is_last_graph:
|
|
||||||
# by converting it to weak ref,
|
|
||||||
# the original `output` will immediately be released
|
|
||||||
# to save memory. It is only safe to do this for
|
|
||||||
# the last graph, because the output of the last graph
|
|
||||||
# will not be used by any other cuda graph.
|
|
||||||
output = weak_ref_tensors(output)
|
|
||||||
|
|
||||||
# here we always use weak ref for the output
|
|
||||||
# to save memory
|
|
||||||
entry.output = weak_ref_tensors(output)
|
|
||||||
entry.cudagraph = cudagraph
|
|
||||||
|
|
||||||
compilation_counter.num_cudagraph_captured += 1
|
|
||||||
|
|
||||||
# important: we need to return the output, rather than
|
|
||||||
# the weak ref of the output, so that pytorch can correctly
|
|
||||||
# manage the memory during cuda graph capture
|
|
||||||
return output
|
|
||||||
|
|
||||||
if self.is_debugging_mode:
|
|
||||||
# check if the input addresses are the same
|
|
||||||
new_input_addresses = [
|
|
||||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
|
||||||
]
|
|
||||||
assert new_input_addresses == entry.input_addresses, (
|
|
||||||
"Input addresses for cudagraphs are different during replay."
|
|
||||||
f" Expected {entry.input_addresses}, got {new_input_addresses}"
|
|
||||||
)
|
|
||||||
|
|
||||||
entry.cudagraph.replay()
|
|
||||||
return entry.output
|
|
||||||
|
|||||||
@ -37,3 +37,21 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig):
|
|||||||
if context_manager is not None:
|
if context_manager is not None:
|
||||||
context_manager.__exit__(None, None, None)
|
context_manager.__exit__(None, None, None)
|
||||||
context_manager = None
|
context_manager = None
|
||||||
|
|
||||||
|
|
||||||
|
cudagraph_capturing_enabled: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
def validate_cudagraph_capturing_enabled():
|
||||||
|
# used to monitor whether an cudagraph capturing is legal at runtime.
|
||||||
|
# should be called before any cudagraph capturing.
|
||||||
|
# if an illegal cudagraph capturing happens, raise an error.
|
||||||
|
global cudagraph_capturing_enabled
|
||||||
|
if not cudagraph_capturing_enabled:
|
||||||
|
raise RuntimeError("CUDA graph capturing detected at an inappropriate "
|
||||||
|
"time. This operation is currently disabled.")
|
||||||
|
|
||||||
|
|
||||||
|
def set_cudagraph_capturing_enabled(enabled: bool):
|
||||||
|
global cudagraph_capturing_enabled
|
||||||
|
cudagraph_capturing_enabled = enabled
|
||||||
|
|||||||
@ -11,7 +11,8 @@ from typing import Callable, Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import CompilationLevel, get_current_vllm_config
|
from vllm.config import (CompilationLevel, CUDAGraphMode,
|
||||||
|
get_current_vllm_config)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -115,8 +116,8 @@ class TorchCompileWrapperWithCustomDispatcher:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if self.vllm_config.compilation_config.use_cudagraph and \
|
if self.vllm_config.compilation_config.cudagraph_mode != \
|
||||||
"update" in new_code.co_names:
|
CUDAGraphMode.NONE and "update" in new_code.co_names:
|
||||||
import depyf
|
import depyf
|
||||||
src = depyf.decompile(new_code)
|
src = depyf.decompile(new_code)
|
||||||
msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa
|
msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa
|
||||||
|
|||||||
@ -29,10 +29,10 @@ from typing_extensions import Self, assert_never, runtime_checkable
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm import version
|
from vllm import version
|
||||||
from vllm.config.cache import (BlockSize, CacheConfig, CacheDType,
|
from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
|
||||||
PrefixCachingHashAlgo)
|
PrefixCachingHashAlgo)
|
||||||
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
|
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
|
||||||
PassConfig)
|
CUDAGraphMode, PassConfig)
|
||||||
from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig
|
from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig
|
||||||
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
|
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
|
||||||
from vllm.config.utils import ConfigType, config
|
from vllm.config.utils import ConfigType, config
|
||||||
@ -388,6 +388,10 @@ class ModelConfig:
|
|||||||
interleave_mm_strings: bool = False
|
interleave_mm_strings: bool = False
|
||||||
"""Enable fully interleaved support for multimodal prompts, while using
|
"""Enable fully interleaved support for multimodal prompts, while using
|
||||||
--chat-template-content-format=string. Defaults to False."""
|
--chat-template-content-format=string. Defaults to False."""
|
||||||
|
skip_mm_profiling: bool = False
|
||||||
|
"""When enabled, skips multimodal memory profiling and only profiles with
|
||||||
|
language backbone model during engine initialization.
|
||||||
|
"""
|
||||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||||
"""Additional args passed to process media inputs, keyed by modalities.
|
"""Additional args passed to process media inputs, keyed by modalities.
|
||||||
For example, to set num_frames for video, set
|
For example, to set num_frames for video, set
|
||||||
@ -837,7 +841,8 @@ class ModelConfig:
|
|||||||
media_io_kwargs=self.media_io_kwargs,
|
media_io_kwargs=self.media_io_kwargs,
|
||||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||||
mm_processor_cache_gb=self.mm_processor_cache_gb,
|
mm_processor_cache_gb=self.mm_processor_cache_gb,
|
||||||
interleave_mm_strings=self.interleave_mm_strings)
|
interleave_mm_strings=self.interleave_mm_strings,
|
||||||
|
skip_mm_profiling=self.skip_mm_profiling)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -2511,6 +2516,16 @@ class MultiModalConfig:
|
|||||||
Enable fully interleaved support for multimodal prompts.
|
Enable fully interleaved support for multimodal prompts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
skip_mm_profiling: bool = False
|
||||||
|
"""
|
||||||
|
When enabled, skips multimodal memory profiling and only profiles with
|
||||||
|
language backbone model during engine initialization.
|
||||||
|
|
||||||
|
This reduces engine startup time but shifts the responsibility to users for
|
||||||
|
estimating the peak memory usage of the activation of multimodal encoder and
|
||||||
|
embedding cache.
|
||||||
|
"""
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
def compute_hash(self) -> str:
|
||||||
"""
|
"""
|
||||||
WARNING: Whenever a new field is added to this config,
|
WARNING: Whenever a new field is added to this config,
|
||||||
@ -3514,11 +3529,21 @@ class VllmConfig:
|
|||||||
else:
|
else:
|
||||||
self.compilation_config.level = \
|
self.compilation_config.level = \
|
||||||
CompilationLevel.NO_COMPILATION
|
CompilationLevel.NO_COMPILATION
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# NB: Passing both --enforce-eager and a compilation level
|
# NB: Passing both --enforce-eager and a compilation level
|
||||||
# in V0 means the compilation level wins out.
|
# in V0 means the compilation level wins out.
|
||||||
self.compilation_config.level = CompilationLevel.NO_COMPILATION
|
self.compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||||
|
|
||||||
|
# if cudagraph_mode is not explicitly set by users, set default value
|
||||||
|
if self.compilation_config.cudagraph_mode is None:
|
||||||
|
if envs.VLLM_USE_V1 and self.compilation_config.level \
|
||||||
|
== CompilationLevel.PIECEWISE:
|
||||||
|
self.compilation_config.cudagraph_mode = \
|
||||||
|
CUDAGraphMode.PIECEWISE
|
||||||
|
else:
|
||||||
|
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||||
|
|
||||||
# async tp is built on top of sequence parallelism
|
# async tp is built on top of sequence parallelism
|
||||||
# and requires it to be enabled.
|
# and requires it to be enabled.
|
||||||
if self.compilation_config.pass_config.enable_async_tp:
|
if self.compilation_config.pass_config.enable_async_tp:
|
||||||
@ -3526,12 +3551,13 @@ class VllmConfig:
|
|||||||
True
|
True
|
||||||
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||||
self.compilation_config.custom_ops.append("+rms_norm")
|
self.compilation_config.custom_ops.append("+rms_norm")
|
||||||
if envs.VLLM_USE_V1 and self.model_config is not None and \
|
|
||||||
not self.model_config.enforce_eager:
|
# disable cudagraph when enforce eager execution
|
||||||
# By default, V1 uses piecewise CUDA graphs. If full_cuda_graph
|
if self.model_config is not None and self.model_config.enforce_eager:
|
||||||
# is set to True, full CUDA graphs will be used.
|
logger.info("Cudagraph is disabled under eager mode")
|
||||||
|
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||||
|
elif envs.VLLM_USE_V1:
|
||||||
self.compilation_config.cudagraph_num_of_warmups = 1
|
self.compilation_config.cudagraph_num_of_warmups = 1
|
||||||
self.compilation_config.set_splitting_ops_for_v1()
|
|
||||||
|
|
||||||
self._set_cudagraph_sizes()
|
self._set_cudagraph_sizes()
|
||||||
|
|
||||||
@ -3551,12 +3577,6 @@ class VllmConfig:
|
|||||||
"Disabling `torch.compile`.")
|
"Disabling `torch.compile`.")
|
||||||
self.compilation_config.level = CompilationLevel.NO_COMPILATION
|
self.compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||||
|
|
||||||
if self.compilation_config.full_cuda_graph and \
|
|
||||||
not self.model_config.disable_cascade_attn:
|
|
||||||
logger.info("full_cuda_graph is not supported with "
|
|
||||||
"cascade attention. Disabling cascade attention.")
|
|
||||||
self.model_config.disable_cascade_attn = True
|
|
||||||
|
|
||||||
disable_chunked_prefill_reasons: list[str] = []
|
disable_chunked_prefill_reasons: list[str] = []
|
||||||
|
|
||||||
if self.model_config and self.model_config.pooler_config:
|
if self.model_config and self.model_config.pooler_config:
|
||||||
@ -3597,9 +3617,32 @@ class VllmConfig:
|
|||||||
"to True to enable.")
|
"to True to enable.")
|
||||||
current_platform.check_and_update_config(self)
|
current_platform.check_and_update_config(self)
|
||||||
|
|
||||||
|
# final check of cudagraph mode after platform-specific update
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \
|
||||||
|
and self.model_config is not None and \
|
||||||
|
not self.model_config.disable_cascade_attn:
|
||||||
|
logger.info("CUDAGraphMode.FULL is not supported with "
|
||||||
|
"cascade attention currently. Disabling cascade"
|
||||||
|
"attention.")
|
||||||
|
self.model_config.disable_cascade_attn = True
|
||||||
|
|
||||||
|
if self.compilation_config.cudagraph_mode\
|
||||||
|
.requires_piecewise_compilation():
|
||||||
|
assert self.compilation_config.level == \
|
||||||
|
CompilationLevel.PIECEWISE, \
|
||||||
|
"Compilation level should be CompilationLevel.PIECEWISE "\
|
||||||
|
"when cudagraph_mode piecewise cudagraphs is used, "\
|
||||||
|
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
|
||||||
|
|
||||||
if not self.instance_id:
|
if not self.instance_id:
|
||||||
self.instance_id = random_uuid()[:5]
|
self.instance_id = random_uuid()[:5]
|
||||||
|
|
||||||
|
# Do this after all the updates to compilation_config.level
|
||||||
|
if envs.VLLM_USE_V1 and \
|
||||||
|
self.compilation_config.level == CompilationLevel.PIECEWISE:
|
||||||
|
self.compilation_config.set_splitting_ops_for_v1()
|
||||||
|
|
||||||
if (envs.VLLM_USE_V1
|
if (envs.VLLM_USE_V1
|
||||||
and not self.scheduler_config.disable_hybrid_kv_cache_manager):
|
and not self.scheduler_config.disable_hybrid_kv_cache_manager):
|
||||||
# logger should only print warning message for hybrid models. As we
|
# logger should only print warning message for hybrid models. As we
|
||||||
|
|||||||
@ -23,6 +23,7 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
BlockSize = Literal[1, 8, 16, 32, 64, 128]
|
BlockSize = Literal[1, 8, 16, 32, 64, 128]
|
||||||
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
|
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
|
||||||
|
MambaDType = Literal["auto", "float32"]
|
||||||
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
|
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
|
||||||
|
|
||||||
|
|
||||||
@ -93,6 +94,15 @@ class CacheConfig:
|
|||||||
""" Optional override for mamba page size; used by hybrid mamba/attention
|
""" Optional override for mamba page size; used by hybrid mamba/attention
|
||||||
models to ensure exact alignment with attention page size."""
|
models to ensure exact alignment with attention page size."""
|
||||||
|
|
||||||
|
mamba_cache_dtype: MambaDType = "auto"
|
||||||
|
"""The data type to use for the Mamba cache (both the conv as well as the
|
||||||
|
ssm state). If set to 'auto', the data type will be inferred from the model
|
||||||
|
config."""
|
||||||
|
mamba_ssm_cache_dtype: MambaDType = "auto"
|
||||||
|
"""The data type to use for the Mamba cache (ssm state only, conv state will
|
||||||
|
still be controlled by mamba_cache_dtype). If set to 'auto', the data type
|
||||||
|
for the ssm state will be determined by mamba_cache_dtype."""
|
||||||
|
|
||||||
# Will be set after profiling.
|
# Will be set after profiling.
|
||||||
num_gpu_blocks: Optional[int] = field(default=None, init=False)
|
num_gpu_blocks: Optional[int] = field(default=None, init=False)
|
||||||
"""The number of blocks to allocate for GPU memory."""
|
"""The number of blocks to allocate for GPU memory."""
|
||||||
@ -123,6 +133,8 @@ class CacheConfig:
|
|||||||
"""
|
"""
|
||||||
factors: list[Any] = []
|
factors: list[Any] = []
|
||||||
factors.append(self.cache_dtype)
|
factors.append(self.cache_dtype)
|
||||||
|
factors.append(self.mamba_cache_dtype)
|
||||||
|
factors.append(self.mamba_ssm_cache_dtype)
|
||||||
# `cpu_offload_gb` does not use `torch.compile` yet.
|
# `cpu_offload_gb` does not use `torch.compile` yet.
|
||||||
hash_str = hashlib.md5(str(factors).encode(),
|
hash_str = hashlib.md5(str(factors).encode(),
|
||||||
usedforsecurity=False).hexdigest()
|
usedforsecurity=False).hexdigest()
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import enum
|
||||||
import hashlib
|
import hashlib
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from dataclasses import asdict, field
|
from dataclasses import asdict, field
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union
|
||||||
|
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter, field_validator
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
@ -31,6 +32,40 @@ class CompilationLevel:
|
|||||||
PIECEWISE = 3
|
PIECEWISE = 3
|
||||||
|
|
||||||
|
|
||||||
|
class CUDAGraphMode(enum.Enum):
|
||||||
|
""" Constants for the cudagraph mode in CompilationConfig.
|
||||||
|
Meanwhile, the subset enum `NONE`, `PIECEWISE` and `FULL` are also
|
||||||
|
treated as concrete runtime mode for cudagraph runtime dispatching.
|
||||||
|
"""
|
||||||
|
NONE = 0
|
||||||
|
PIECEWISE = 1
|
||||||
|
FULL = 2
|
||||||
|
FULL_DECODE_ONLY = (FULL, NONE)
|
||||||
|
FULL_AND_PIECEWISE = (FULL, PIECEWISE)
|
||||||
|
|
||||||
|
def decode_mode(self) -> 'CUDAGraphMode':
|
||||||
|
return CUDAGraphMode(self.value[0]) if \
|
||||||
|
self.separate_routine() else self
|
||||||
|
|
||||||
|
def mixed_mode(self) -> 'CUDAGraphMode':
|
||||||
|
return CUDAGraphMode(self.value[1]) if \
|
||||||
|
self.separate_routine() else self
|
||||||
|
|
||||||
|
def requires_piecewise_compilation(self) -> bool:
|
||||||
|
return (self.decode_mode() == CUDAGraphMode.PIECEWISE
|
||||||
|
or self.mixed_mode() == CUDAGraphMode.PIECEWISE)
|
||||||
|
|
||||||
|
def max_cudagraph_mode(self) -> 'CUDAGraphMode':
|
||||||
|
return CUDAGraphMode(max(
|
||||||
|
self.value)) if self.separate_routine() else self
|
||||||
|
|
||||||
|
def has_full_cudagraphs(self) -> bool:
|
||||||
|
return self.max_cudagraph_mode() == CUDAGraphMode.FULL
|
||||||
|
|
||||||
|
def separate_routine(self) -> bool:
|
||||||
|
return isinstance(self.value, tuple)
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
@dataclass
|
@dataclass
|
||||||
class PassConfig:
|
class PassConfig:
|
||||||
@ -91,6 +126,7 @@ class CompilationConfig:
|
|||||||
- [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
|
- [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
|
||||||
- CudaGraph capture:
|
- CudaGraph capture:
|
||||||
- [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
|
- [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
|
||||||
|
- [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
|
||||||
- [`cudagraph_capture_sizes`]
|
- [`cudagraph_capture_sizes`]
|
||||||
[vllm.config.CompilationConfig.cudagraph_capture_sizes]
|
[vllm.config.CompilationConfig.cudagraph_capture_sizes]
|
||||||
- [`cudagraph_num_of_warmups`]
|
- [`cudagraph_num_of_warmups`]
|
||||||
@ -157,7 +193,7 @@ class CompilationConfig:
|
|||||||
By default, all custom ops are enabled when running without Inductor and
|
By default, all custom ops are enabled when running without Inductor and
|
||||||
disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
|
disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
|
||||||
Inductor generates (fused) Triton kernels for disabled custom ops."""
|
Inductor generates (fused) Triton kernels for disabled custom ops."""
|
||||||
splitting_ops: list[str] = field(default_factory=list)
|
splitting_ops: Optional[list[str]] = None
|
||||||
"""A list of ops to split the full graph into subgraphs, used in piecewise
|
"""A list of ops to split the full graph into subgraphs, used in piecewise
|
||||||
compilation."""
|
compilation."""
|
||||||
|
|
||||||
@ -187,7 +223,43 @@ class CompilationConfig:
|
|||||||
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
|
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
|
||||||
|
|
||||||
# CudaGraph compilation
|
# CudaGraph compilation
|
||||||
use_cudagraph: bool = field(default_factory=lambda: envs.VLLM_USE_V1)
|
cudagraph_mode: Optional[CUDAGraphMode] = None
|
||||||
|
"""
|
||||||
|
The mode of the cudagraph.
|
||||||
|
- NONE, no cudagraph capture.
|
||||||
|
- PIECEWISE. (v1 default)
|
||||||
|
- FULL.
|
||||||
|
- FULL_DECODE_ONLY.
|
||||||
|
- FULL_AND_PIECEWISE.
|
||||||
|
|
||||||
|
PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph
|
||||||
|
incompatiable ops (i.e. some attention ops) outside the cudagraph
|
||||||
|
for general flexibility.
|
||||||
|
This is the default mode.
|
||||||
|
|
||||||
|
FULL mode: Capture full cudagraph for all batches. Can be good for small
|
||||||
|
models or workloads with small prompts; not supported by many backends.
|
||||||
|
Generally for performance FULL_AND_PIECEWISE is better.
|
||||||
|
|
||||||
|
FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only.
|
||||||
|
Mixed prefill-decode batches are run without cudagraphs. Can be good for
|
||||||
|
decode instances in a P/D setup where prefill is not as important so we
|
||||||
|
can save some memory.
|
||||||
|
|
||||||
|
FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
|
||||||
|
piecewise cudagraph for prefill and mixed prefill-decode batches.
|
||||||
|
This is like the most performant mode for most models.
|
||||||
|
|
||||||
|
Currently, the cudagraph mode is only used for the v1 engine.
|
||||||
|
Note that the cudagraph logic is generally orthogonal to the
|
||||||
|
compilation logic. While piecewise cudagraphs require piecewise
|
||||||
|
compilation (level=PIECEWISE and non-empty splitting_ops), full
|
||||||
|
cudagraphs are supported with and without compilation.
|
||||||
|
|
||||||
|
Warning: This flag is new and subject to change in addition
|
||||||
|
more modes may be added.
|
||||||
|
"""
|
||||||
|
use_cudagraph: bool = True
|
||||||
"""Whether to use cudagraph inside compilation.
|
"""Whether to use cudagraph inside compilation.
|
||||||
- False: cudagraph inside compilation is not used.
|
- False: cudagraph inside compilation is not used.
|
||||||
- True: cudagraph inside compilation is used. It requires
|
- True: cudagraph inside compilation is used. It requires
|
||||||
@ -197,8 +269,9 @@ class CompilationConfig:
|
|||||||
CompilationLevel.PIECEWISE (aka -O3).
|
CompilationLevel.PIECEWISE (aka -O3).
|
||||||
Note that this is orthogonal to the cudagraph capture logic
|
Note that this is orthogonal to the cudagraph capture logic
|
||||||
outside of compilation.
|
outside of compilation.
|
||||||
TODO: move outside cudagraph logic into compilation.
|
Warning: This flag is deprecated and will be removed in the next major or
|
||||||
torch.compile will handle cudagraph capture logic in the future."""
|
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
|
||||||
|
"""
|
||||||
cudagraph_num_of_warmups: int = 0
|
cudagraph_num_of_warmups: int = 0
|
||||||
"""Number of warmup runs for cudagraph.
|
"""Number of warmup runs for cudagraph.
|
||||||
It means the first several runs will be treated as warmup runs.
|
It means the first several runs will be treated as warmup runs.
|
||||||
@ -213,12 +286,17 @@ class CompilationConfig:
|
|||||||
cudagraph. If the caller can guarantee that the same input buffers
|
cudagraph. If the caller can guarantee that the same input buffers
|
||||||
are always used, it can set this to False. Otherwise, it should
|
are always used, it can set this to False. Otherwise, it should
|
||||||
set this to True, and the compiler will copy the input to an
|
set this to True, and the compiler will copy the input to an
|
||||||
internally managed buffer. Default is False."""
|
internally managed buffer. Default is False.
|
||||||
full_cuda_graph: bool = False
|
Note that this flag is only effective when cudagraph_mode is PIECEWISE.
|
||||||
|
"""
|
||||||
|
full_cuda_graph: Optional[bool] = False
|
||||||
"""whether to use a full cuda graph for the entire forward pass rather than
|
"""whether to use a full cuda graph for the entire forward pass rather than
|
||||||
splitting certain operations such as attention into subgraphs. Thus this
|
splitting certain operations such as attention into subgraphs. Thus this
|
||||||
flag cannot be used together with splitting_ops. This may provide
|
flag cannot be used together with splitting_ops. This may provide
|
||||||
performance benefits for smaller models."""
|
performance benefits for smaller models.
|
||||||
|
Warning: This flag is deprecated and will be removed in the next major or
|
||||||
|
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
|
||||||
|
"""
|
||||||
|
|
||||||
pass_config: PassConfig = field(default_factory=PassConfig)
|
pass_config: PassConfig = field(default_factory=PassConfig)
|
||||||
"""Custom inductor passes, see PassConfig for more details"""
|
"""Custom inductor passes, see PassConfig for more details"""
|
||||||
@ -253,6 +331,13 @@ class CompilationConfig:
|
|||||||
Map from layer name to layer objects that need to be accessed outside
|
Map from layer name to layer objects that need to be accessed outside
|
||||||
model code, e.g., Attention, FusedMOE when dp_size>1."""
|
model code, e.g., Attention, FusedMOE when dp_size>1."""
|
||||||
|
|
||||||
|
# Attention ops; used for piecewise cudagraphs
|
||||||
|
_attention_ops: ClassVar[list[str]] = [
|
||||||
|
"vllm.unified_attention",
|
||||||
|
"vllm.unified_attention_with_output",
|
||||||
|
"vllm.mamba_mixer2",
|
||||||
|
]
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
def compute_hash(self) -> str:
|
||||||
"""
|
"""
|
||||||
WARNING: Whenever a new field is added to this config,
|
WARNING: Whenever a new field is added to this config,
|
||||||
@ -297,13 +382,26 @@ class CompilationConfig:
|
|||||||
if pass_config_exclude:
|
if pass_config_exclude:
|
||||||
exclude["pass_config"] = pass_config_exclude
|
exclude["pass_config"] = pass_config_exclude
|
||||||
|
|
||||||
return TypeAdapter(CompilationConfig).dump_json(
|
# The cast to string is necessary because Pydantic is mocked in docs
|
||||||
self,
|
# builds and sphinx-argparse doesn't know the return type of decode()
|
||||||
exclude=exclude, # type: ignore[arg-type]
|
return str(
|
||||||
exclude_unset=True).decode()
|
TypeAdapter(CompilationConfig).dump_json(
|
||||||
|
self,
|
||||||
|
exclude=exclude, # type: ignore[arg-type]
|
||||||
|
exclude_unset=True).decode())
|
||||||
|
|
||||||
__str__ = __repr__
|
__str__ = __repr__
|
||||||
|
|
||||||
|
@field_validator("cudagraph_mode", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_cudagraph_mode_before(cls, value: Any) -> Any:
|
||||||
|
"""
|
||||||
|
enable parse the `cudagraph_mode` enum type from string
|
||||||
|
"""
|
||||||
|
if isinstance(value, str):
|
||||||
|
return CUDAGraphMode[value.upper()]
|
||||||
|
return value
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
count_none = self.custom_ops.count("none")
|
count_none = self.custom_ops.count("none")
|
||||||
count_all = self.custom_ops.count("all")
|
count_all = self.custom_ops.count("all")
|
||||||
@ -341,7 +439,26 @@ class CompilationConfig:
|
|||||||
if isinstance(self.pass_config, dict):
|
if isinstance(self.pass_config, dict):
|
||||||
self.pass_config = PassConfig(**self.pass_config)
|
self.pass_config = PassConfig(**self.pass_config)
|
||||||
|
|
||||||
def init_backend(self, vllm_config: VllmConfig) -> Union[str, Callable]:
|
# migrate the deprecated flags
|
||||||
|
if not self.use_cudagraph:
|
||||||
|
logger.warning("use_cudagraph is deprecated, use "
|
||||||
|
"cudagraph_mode=NONE instead.")
|
||||||
|
if self.cudagraph_mode is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"use_cudagraph and cudagraph_mode are mutually"
|
||||||
|
" exclusive, prefer cudagraph_mode since "
|
||||||
|
"use_cudagraph is deprecated.")
|
||||||
|
self.cudagraph_mode = CUDAGraphMode.NONE
|
||||||
|
if self.full_cuda_graph:
|
||||||
|
logger.warning("full_cuda_graph is deprecated, use "
|
||||||
|
"cudagraph_mode=FULL instead.")
|
||||||
|
if self.cudagraph_mode is not None:
|
||||||
|
raise ValueError("full_cuda_graph and cudagraph_mode are "
|
||||||
|
"mutually exclusive, prefer cudagraph_mode "
|
||||||
|
"since full_cuda_graph is deprecated.")
|
||||||
|
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||||
|
|
||||||
|
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
|
||||||
if self.level == CompilationLevel.NO_COMPILATION:
|
if self.level == CompilationLevel.NO_COMPILATION:
|
||||||
raise ValueError("No compilation level is set.")
|
raise ValueError("No compilation level is set.")
|
||||||
|
|
||||||
@ -414,15 +531,34 @@ class CompilationConfig:
|
|||||||
self.max_capture_size] = self.max_capture_size
|
self.max_capture_size] = self.max_capture_size
|
||||||
|
|
||||||
def set_splitting_ops_for_v1(self):
|
def set_splitting_ops_for_v1(self):
|
||||||
# NOTE: this function needs to be called
|
# NOTE: this function needs to be called only when level is
|
||||||
if self.splitting_ops and self.full_cuda_graph:
|
# CompilationLevel.PIECEWISE
|
||||||
raise ValueError("full_cuda_graph cannot be used together with "
|
assert self.level == CompilationLevel.PIECEWISE, (
|
||||||
"splitting_ops, as Full CUDA graph will override "
|
"set_splitting_ops_for_v1 should only be called when "
|
||||||
f"the splitting_ops: {self.splitting_ops}")
|
"level is CompilationLevel.PIECEWISE")
|
||||||
|
|
||||||
if not self.splitting_ops:
|
if self.splitting_ops is None:
|
||||||
self.splitting_ops = [] if self.full_cuda_graph else [
|
# NOTE: When using full cudagraph, instead of setting an empty
|
||||||
"vllm.unified_attention",
|
# list and capture the full cudagraph inside the flattened fx
|
||||||
"vllm.unified_attention_with_output",
|
# graph, we keep the piecewise fx graph structure but capture the
|
||||||
"vllm.mamba_mixer2",
|
# full cudagraph outside the fx graph. This reduces some cpu
|
||||||
]
|
# overhead when the runtime batch_size is not cudagraph captured.
|
||||||
|
# see https://github.com/vllm-project/vllm/pull/20059 for details.
|
||||||
|
self.splitting_ops = self._attention_ops
|
||||||
|
elif len(self.splitting_ops) == 0:
|
||||||
|
logger.warning_once("Using piecewise compilation with empty "
|
||||||
|
"splitting_ops.")
|
||||||
|
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
|
||||||
|
logger.warning_once(
|
||||||
|
"When compilation level is piecewise with empty "
|
||||||
|
"splitting_ops, PIECEWISE cudagraph_mode will be "
|
||||||
|
"treated as FULL cudagraph_mode. Please ensure you are "
|
||||||
|
"using attention backends that support cudagraph or set "
|
||||||
|
"cudagraph_mode to NONE explicitly if encountering "
|
||||||
|
"any problems.")
|
||||||
|
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||||
|
self.splitting_ops = []
|
||||||
|
|
||||||
|
def splitting_ops_contain_attention(self) -> bool:
|
||||||
|
return self.splitting_ops is not None and all(
|
||||||
|
op in self.splitting_ops for op in self._attention_ops)
|
||||||
|
|||||||
@ -325,4 +325,8 @@ class KVConnectorBase_V1(ABC):
|
|||||||
str: the required KV cache layout. e.g. HND, or NHD.
|
str: the required KV cache layout. e.g. HND, or NHD.
|
||||||
None if the connector does not require a specific layout.
|
None if the connector does not require a specific layout.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if cls is KVConnectorBase_V1:
|
||||||
|
raise TypeError("get_required_kvcache_layout should not be called "
|
||||||
|
"on the abstract base class")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@ -228,9 +228,10 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
for ktc in ktcs:
|
for ktc in ktcs:
|
||||||
kv_transfer_config = KVTransferConfig(**ktc)
|
kv_transfer_config = KVTransferConfig(**ktc)
|
||||||
temp_vllm_config.kv_transfer_config = kv_transfer_config
|
temp_vllm_config.kv_transfer_config = kv_transfer_config
|
||||||
|
connector_cls = KVConnectorFactory.get_connector_class(
|
||||||
|
kv_transfer_config)
|
||||||
required_kvcache_layout = (
|
required_kvcache_layout = (
|
||||||
KVConnectorBase_V1.get_required_kvcache_layout(
|
connector_cls.get_required_kvcache_layout(temp_vllm_config))
|
||||||
temp_vllm_config))
|
|
||||||
if required_kvcache_layout is not None:
|
if required_kvcache_layout is not None:
|
||||||
layouts.add(required_kvcache_layout)
|
layouts.add(required_kvcache_layout)
|
||||||
|
|
||||||
|
|||||||
@ -27,12 +27,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
|
|||||||
DeviceConfig, DistributedExecutorBackend,
|
DeviceConfig, DistributedExecutorBackend,
|
||||||
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
|
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
|
||||||
KVTransferConfig, LoadConfig, LogprobsMode,
|
KVTransferConfig, LoadConfig, LogprobsMode,
|
||||||
LoRAConfig, ModelConfig, ModelDType, ModelImpl,
|
LoRAConfig, MambaDType, ModelConfig, ModelDType,
|
||||||
MultiModalConfig, ObservabilityConfig, ParallelConfig,
|
ModelImpl, MultiModalConfig, ObservabilityConfig,
|
||||||
PoolerConfig, PrefixCachingHashAlgo, RunnerOption,
|
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
|
||||||
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
|
RunnerOption, SchedulerConfig, SchedulerPolicy,
|
||||||
TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
|
SpeculativeConfig, TaskOption, TokenizerMode,
|
||||||
get_field)
|
VllmConfig, get_attr_docs, get_field)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import CpuArchEnum, current_platform
|
from vllm.platforms import CpuArchEnum, current_platform
|
||||||
from vllm.plugins import load_general_plugins
|
from vllm.plugins import load_general_plugins
|
||||||
@ -350,6 +350,7 @@ class EngineArgs:
|
|||||||
MultiModalConfig.mm_processor_kwargs
|
MultiModalConfig.mm_processor_kwargs
|
||||||
disable_mm_preprocessor_cache: bool = False # DEPRECATED
|
disable_mm_preprocessor_cache: bool = False # DEPRECATED
|
||||||
mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
|
mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
|
||||||
|
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
|
||||||
# LoRA fields
|
# LoRA fields
|
||||||
enable_lora: bool = False
|
enable_lora: bool = False
|
||||||
enable_lora_bias: bool = LoRAConfig.bias_enabled
|
enable_lora_bias: bool = LoRAConfig.bias_enabled
|
||||||
@ -421,6 +422,8 @@ class EngineArgs:
|
|||||||
override_attention_dtype: str = ModelConfig.override_attention_dtype
|
override_attention_dtype: str = ModelConfig.override_attention_dtype
|
||||||
|
|
||||||
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
|
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
|
||||||
|
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
|
||||||
|
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
|
||||||
|
|
||||||
additional_config: dict[str, Any] = \
|
additional_config: dict[str, Any] = \
|
||||||
get_field(VllmConfig, "additional_config")
|
get_field(VllmConfig, "additional_config")
|
||||||
@ -693,6 +696,10 @@ class EngineArgs:
|
|||||||
**cache_kwargs["calculate_kv_scales"])
|
**cache_kwargs["calculate_kv_scales"])
|
||||||
cache_group.add_argument("--kv-sharing-fast-prefill",
|
cache_group.add_argument("--kv-sharing-fast-prefill",
|
||||||
**cache_kwargs["kv_sharing_fast_prefill"])
|
**cache_kwargs["kv_sharing_fast_prefill"])
|
||||||
|
cache_group.add_argument("--mamba-cache-dtype",
|
||||||
|
**cache_kwargs["mamba_cache_dtype"])
|
||||||
|
cache_group.add_argument("--mamba-ssm-cache-dtype",
|
||||||
|
**cache_kwargs["mamba_ssm_cache_dtype"])
|
||||||
|
|
||||||
# Multimodal related configs
|
# Multimodal related configs
|
||||||
multimodal_kwargs = get_kwargs(MultiModalConfig)
|
multimodal_kwargs = get_kwargs(MultiModalConfig)
|
||||||
@ -716,6 +723,8 @@ class EngineArgs:
|
|||||||
multimodal_group.add_argument(
|
multimodal_group.add_argument(
|
||||||
"--interleave-mm-strings",
|
"--interleave-mm-strings",
|
||||||
**multimodal_kwargs["interleave_mm_strings"])
|
**multimodal_kwargs["interleave_mm_strings"])
|
||||||
|
multimodal_group.add_argument("--skip-mm-profiling",
|
||||||
|
**multimodal_kwargs["skip_mm_profiling"])
|
||||||
|
|
||||||
# LoRA related configs
|
# LoRA related configs
|
||||||
lora_kwargs = get_kwargs(LoRAConfig)
|
lora_kwargs = get_kwargs(LoRAConfig)
|
||||||
@ -918,6 +927,7 @@ class EngineArgs:
|
|||||||
limit_mm_per_prompt=self.limit_mm_per_prompt,
|
limit_mm_per_prompt=self.limit_mm_per_prompt,
|
||||||
interleave_mm_strings=self.interleave_mm_strings,
|
interleave_mm_strings=self.interleave_mm_strings,
|
||||||
media_io_kwargs=self.media_io_kwargs,
|
media_io_kwargs=self.media_io_kwargs,
|
||||||
|
skip_mm_profiling=self.skip_mm_profiling,
|
||||||
use_async_output_proc=not self.disable_async_output_proc,
|
use_async_output_proc=not self.disable_async_output_proc,
|
||||||
config_format=self.config_format,
|
config_format=self.config_format,
|
||||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||||
@ -1101,6 +1111,8 @@ class EngineArgs:
|
|||||||
cpu_offload_gb=self.cpu_offload_gb,
|
cpu_offload_gb=self.cpu_offload_gb,
|
||||||
calculate_kv_scales=self.calculate_kv_scales,
|
calculate_kv_scales=self.calculate_kv_scales,
|
||||||
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
|
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
|
||||||
|
mamba_cache_dtype=self.mamba_cache_dtype,
|
||||||
|
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
ray_runtime_env = None
|
ray_runtime_env = None
|
||||||
|
|||||||
@ -126,7 +126,7 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
async def _force_log():
|
async def _force_log():
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(10.)
|
await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
|
||||||
await engine_client.do_log_stats()
|
await engine_client.do_log_stats()
|
||||||
|
|
||||||
task = asyncio.create_task(_force_log())
|
task = asyncio.create_task(_force_log())
|
||||||
|
|||||||
25
vllm/envs.py
25
vllm/envs.py
@ -38,6 +38,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_LOGGING_PREFIX: str = ""
|
VLLM_LOGGING_PREFIX: str = ""
|
||||||
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
|
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
|
||||||
VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None
|
VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None
|
||||||
|
VLLM_LOG_STATS_INTERVAL: float = 10.
|
||||||
VLLM_TRACE_FUNCTION: int = 0
|
VLLM_TRACE_FUNCTION: int = 0
|
||||||
VLLM_ATTENTION_BACKEND: Optional[str] = None
|
VLLM_ATTENTION_BACKEND: Optional[str] = None
|
||||||
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
|
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
|
||||||
@ -122,6 +123,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_MOE_DP_CHUNK_SIZE: int = 256
|
VLLM_MOE_DP_CHUNK_SIZE: int = 256
|
||||||
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
|
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
|
||||||
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
|
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
|
||||||
|
VLLM_MXFP4_USE_MARLIN: Optional[bool] = None
|
||||||
VLLM_V0_USE_OUTLINES_CACHE: bool = False
|
VLLM_V0_USE_OUTLINES_CACHE: bool = False
|
||||||
VLLM_V1_USE_OUTLINES_CACHE: bool = False
|
VLLM_V1_USE_OUTLINES_CACHE: bool = False
|
||||||
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
|
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
|
||||||
@ -182,6 +184,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
|
|||||||
return int(value)
|
return int(value)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_convert_bool(value: Optional[str]) -> Optional[bool]:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return bool(int(value))
|
||||||
|
|
||||||
|
|
||||||
def get_vllm_port() -> Optional[int]:
|
def get_vllm_port() -> Optional[int]:
|
||||||
"""Get the port from VLLM_PORT environment variable.
|
"""Get the port from VLLM_PORT environment variable.
|
||||||
|
|
||||||
@ -429,6 +437,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
lambda: int(os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0"))
|
lambda: int(os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0"))
|
||||||
if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ else None,
|
if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ else None,
|
||||||
|
|
||||||
|
# If set, vllm will log stats at this interval in seconds
|
||||||
|
# If not set, vllm will log stats every 10 seconds.
|
||||||
|
"VLLM_LOG_STATS_INTERVAL":
|
||||||
|
lambda: val if (val := float(os.getenv("VLLM_LOG_STATS_INTERVAL", "10.")))
|
||||||
|
> 0. else 10.,
|
||||||
|
|
||||||
# Trace function calls
|
# Trace function calls
|
||||||
# If set to 1, vllm will trace function calls
|
# If set to 1, vllm will trace function calls
|
||||||
# Useful for debugging
|
# Useful for debugging
|
||||||
@ -906,6 +920,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_MARLIN_USE_ATOMIC_ADD":
|
"VLLM_MARLIN_USE_ATOMIC_ADD":
|
||||||
lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1",
|
lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1",
|
||||||
|
|
||||||
|
# Whether to use marlin kernel in mxfp4 quantization method
|
||||||
|
"VLLM_MXFP4_USE_MARLIN":
|
||||||
|
lambda: maybe_convert_bool(os.environ.get("VLLM_MXFP4_USE_MARLIN", None)),
|
||||||
|
|
||||||
# Whether to turn on the outlines cache for V0
|
# Whether to turn on the outlines cache for V0
|
||||||
# This cache is unbounded and on disk, so it's not safe to use in
|
# This cache is unbounded and on disk, so it's not safe to use in
|
||||||
# an environment with potentially malicious users.
|
# an environment with potentially malicious users.
|
||||||
@ -1090,6 +1108,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_USE_TRTLLM_ATTENTION":
|
"VLLM_USE_TRTLLM_ATTENTION":
|
||||||
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
|
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
|
||||||
|
|
||||||
|
# If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
|
||||||
|
# Otherwise, uses the first available of: flashinfer cutlass GEMM,
|
||||||
|
# vllm cutlass GEMM, marlin GEMM.
|
||||||
|
"VLLM_USE_TRTLLM_FP4_GEMM":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))),
|
||||||
|
|
||||||
# Controls garbage collection during CUDA graph capture.
|
# Controls garbage collection during CUDA graph capture.
|
||||||
# If set to 0 (default), enables GC freezing to speed up capture time.
|
# If set to 0 (default), enables GC freezing to speed up capture time.
|
||||||
# If set to 1, allows GC to run during capture.
|
# If set to 1, allows GC to run during capture.
|
||||||
@ -1197,6 +1221,7 @@ def compute_hash() -> str:
|
|||||||
"VLLM_DP_SIZE",
|
"VLLM_DP_SIZE",
|
||||||
"VLLM_USE_STANDALONE_COMPILE",
|
"VLLM_USE_STANDALONE_COMPILE",
|
||||||
"VLLM_FUSED_MOE_CHUNK_SIZE",
|
"VLLM_FUSED_MOE_CHUNK_SIZE",
|
||||||
|
"VLLM_USE_TRTLLM_FP4_GEMM",
|
||||||
]
|
]
|
||||||
for key in environment_variables_to_hash:
|
for key in environment_variables_to_hash:
|
||||||
if key in environment_variables:
|
if key in environment_variables:
|
||||||
|
|||||||
@ -5,13 +5,13 @@ import time
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import ParallelConfig, VllmConfig
|
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -26,6 +26,27 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
|
|||||||
batchsize_forward_time: defaultdict = defaultdict(list)
|
batchsize_forward_time: defaultdict = defaultdict(list)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDescriptor(NamedTuple):
|
||||||
|
"""
|
||||||
|
Batch descriptor for cudagraph dispatching. We should keep the num of
|
||||||
|
items as minimal as possible to properly and uniquely describe the padded
|
||||||
|
batch for cudagraph.
|
||||||
|
"""
|
||||||
|
num_tokens: int
|
||||||
|
uniform_decode: bool = False
|
||||||
|
"""
|
||||||
|
False can also be used for an uniform decode batch to dispatch to the
|
||||||
|
cudagraph supporting non-uniform batches.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def non_uniform(self) -> "BatchDescriptor":
|
||||||
|
"""
|
||||||
|
Return a non-uniform version of current batch descriptor.
|
||||||
|
"""
|
||||||
|
return BatchDescriptor(self.num_tokens, uniform_decode=False)
|
||||||
|
|
||||||
|
|
||||||
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
|
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
|
||||||
max_num_tokens: int,
|
max_num_tokens: int,
|
||||||
chunk_idx: int) -> list[int]:
|
chunk_idx: int) -> list[int]:
|
||||||
@ -152,7 +173,15 @@ class ForwardContext:
|
|||||||
virtual_engine: int # set dynamically for each forward pass
|
virtual_engine: int # set dynamically for each forward pass
|
||||||
# set dynamically for each forward pass
|
# set dynamically for each forward pass
|
||||||
dp_metadata: Optional[DPMetadata] = None
|
dp_metadata: Optional[DPMetadata] = None
|
||||||
skip_cuda_graphs: bool = False
|
# determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
|
||||||
|
# by default NONE, no cudagraph is used.
|
||||||
|
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
|
||||||
|
batch_descriptor: Optional[BatchDescriptor] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
assert self.cudagraph_runtime_mode in [
|
||||||
|
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
|
||||||
|
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
|
||||||
|
|
||||||
|
|
||||||
_forward_context: Optional[ForwardContext] = None
|
_forward_context: Optional[ForwardContext] = None
|
||||||
@ -168,13 +197,13 @@ def get_forward_context() -> ForwardContext:
|
|||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def set_forward_context(
|
def set_forward_context(
|
||||||
attn_metadata: Any,
|
attn_metadata: Any,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
virtual_engine: int = 0,
|
virtual_engine: int = 0,
|
||||||
num_tokens: Optional[int] = None,
|
num_tokens: Optional[int] = None,
|
||||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||||
skip_cuda_graphs: bool = False,
|
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||||
):
|
batch_descriptor: Optional[BatchDescriptor] = None):
|
||||||
"""A context manager that stores the current forward context,
|
"""A context manager that stores the current forward context,
|
||||||
can be attention metadata, etc.
|
can be attention metadata, etc.
|
||||||
Here we can inject common logic for every model forward pass.
|
Here we can inject common logic for every model forward pass.
|
||||||
@ -198,7 +227,8 @@ def set_forward_context(
|
|||||||
virtual_engine=virtual_engine,
|
virtual_engine=virtual_engine,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
dp_metadata=dp_metadata,
|
dp_metadata=dp_metadata,
|
||||||
skip_cuda_graphs=skip_cuda_graphs,
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
|
batch_descriptor=batch_descriptor,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -18,6 +18,8 @@ from vllm.utils import direct_register_custom_op
|
|||||||
def fused_marlin_moe(hidden_states: torch.Tensor,
|
def fused_marlin_moe(hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
|
bias1: Optional[torch.Tensor],
|
||||||
|
bias2: Optional[torch.Tensor],
|
||||||
w1_scale: torch.Tensor,
|
w1_scale: torch.Tensor,
|
||||||
w2_scale: torch.Tensor,
|
w2_scale: torch.Tensor,
|
||||||
gating_output: torch.Tensor,
|
gating_output: torch.Tensor,
|
||||||
@ -26,6 +28,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
quant_type_id: int,
|
quant_type_id: int,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
|
activation: Optional[str] = "silu",
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
global_scale1: Optional[torch.Tensor] = None,
|
global_scale1: Optional[torch.Tensor] = None,
|
||||||
global_scale2: Optional[torch.Tensor] = None,
|
global_scale2: Optional[torch.Tensor] = None,
|
||||||
@ -88,6 +91,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||||
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
|
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
|
||||||
assert num_bits in [4, 8]
|
assert num_bits in [4, 8]
|
||||||
|
assert topk_weights.dtype == torch.float32
|
||||||
|
|
||||||
M, K = hidden_states.shape
|
M, K = hidden_states.shape
|
||||||
E = w1.shape[0]
|
E = w1.shape[0]
|
||||||
@ -138,6 +142,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
intermediate_cache1,
|
intermediate_cache1,
|
||||||
w1,
|
w1,
|
||||||
|
bias1,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
global_scale1,
|
global_scale1,
|
||||||
w1_zeros,
|
w1_zeros,
|
||||||
@ -161,8 +166,28 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
use_fp32_reduce=True,
|
use_fp32_reduce=True,
|
||||||
is_zp_float=False)
|
is_zp_float=False)
|
||||||
|
|
||||||
torch.ops._C.silu_and_mul(intermediate_cache2,
|
if activation == "silu":
|
||||||
intermediate_cache1.view(-1, 2 * N))
|
torch.ops._C.silu_and_mul(intermediate_cache2,
|
||||||
|
intermediate_cache1.view(-1, 2 * N))
|
||||||
|
elif activation == "swiglu_oai":
|
||||||
|
# NOTE: in gpt-oss, the gate_proj and up_proj is interleaved
|
||||||
|
# - interleaved: gate, up = gate_up[..., ::2], gate_up[..., 1::2]
|
||||||
|
# - origin: gate, up = gate_up[..., :N], gate_up[..., N:]
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True)
|
||||||
|
def swiglu_oai(gate_up):
|
||||||
|
alpha = 1.702
|
||||||
|
limit = 7.0
|
||||||
|
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
|
||||||
|
gate = gate.clamp(min=None, max=limit)
|
||||||
|
up = up.clamp(min=-limit, max=limit)
|
||||||
|
glu = gate * torch.sigmoid(gate * alpha)
|
||||||
|
return (up + 1) * glu
|
||||||
|
|
||||||
|
intermediate_cache2 = swiglu_oai(intermediate_cache1)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported activation: {activation}. "
|
||||||
|
"Only silu and swiglu_oai activations are supported.")
|
||||||
|
|
||||||
if expert_map is not None:
|
if expert_map is not None:
|
||||||
intermediate_cache3.zero_()
|
intermediate_cache3.zero_()
|
||||||
@ -171,6 +196,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
intermediate_cache2,
|
intermediate_cache2,
|
||||||
intermediate_cache3,
|
intermediate_cache3,
|
||||||
w2,
|
w2,
|
||||||
|
bias2,
|
||||||
w2_scale,
|
w2_scale,
|
||||||
global_scale2,
|
global_scale2,
|
||||||
w2_zeros,
|
w2_zeros,
|
||||||
|
|||||||
@ -1189,10 +1189,10 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
input_scale: torch.Tensor,
|
input_scale: torch.Tensor,
|
||||||
gemm1_weights: torch.Tensor,
|
gemm1_weights: torch.Tensor,
|
||||||
gemm1_weights_scale: torch.Tensor,
|
|
||||||
activation_scale: torch.Tensor,
|
|
||||||
gemm2_weights: torch.Tensor,
|
gemm2_weights: torch.Tensor,
|
||||||
gemm2_weights_scale: torch.Tensor,
|
output1_scales_scalar: torch.Tensor,
|
||||||
|
output1_scales_gate_scalar: torch.Tensor,
|
||||||
|
output2_scales_scalar: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
num_expert_group: Optional[int],
|
num_expert_group: Optional[int],
|
||||||
@ -1206,17 +1206,12 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
|
|||||||
num_expert_group = num_expert_group if num_expert_group is not None else 0
|
num_expert_group = num_expert_group if num_expert_group is not None else 0
|
||||||
topk_group = topk_group if topk_group is not None else 0
|
topk_group = topk_group if topk_group is not None else 0
|
||||||
|
|
||||||
quant_hidden_states, input_scale = moe_kernel_quantize_input(
|
quant_hidden_states, _ = moe_kernel_quantize_input(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
input_scale,
|
input_scale,
|
||||||
quant_dtype=torch.float8_e4m3fn,
|
quant_dtype=torch.float8_e4m3fn,
|
||||||
per_act_token_quant=False)
|
per_act_token_quant=False)
|
||||||
|
|
||||||
output1_scales_scalar = gemm1_weights_scale * input_scale * (
|
|
||||||
1.0 / activation_scale)
|
|
||||||
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
|
|
||||||
output2_scales_scalar = activation_scale * gemm2_weights_scale
|
|
||||||
|
|
||||||
from vllm.utils.flashinfer import (
|
from vllm.utils.flashinfer import (
|
||||||
flashinfer_trtllm_fp8_per_tensor_scale_moe)
|
flashinfer_trtllm_fp8_per_tensor_scale_moe)
|
||||||
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
|
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
|
||||||
@ -1244,24 +1239,24 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
|
|||||||
|
|
||||||
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
|
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
|
||||||
routing_logits: torch.Tensor,
|
routing_logits: torch.Tensor,
|
||||||
routing_bias: torch.Tensor,
|
routing_bias: Optional[torch.Tensor],
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
input_scale: torch.Tensor,
|
||||||
gemm1_weights: torch.Tensor,
|
gemm1_weights: torch.Tensor,
|
||||||
|
gemm2_weights: torch.Tensor,
|
||||||
output1_scales_scalar: torch.Tensor,
|
output1_scales_scalar: torch.Tensor,
|
||||||
output1_scales_gate_scalar: torch.Tensor,
|
output1_scales_gate_scalar: torch.Tensor,
|
||||||
gemm2_weights: torch.Tensor,
|
|
||||||
output2_scales_scalar: torch.Tensor,
|
output2_scales_scalar: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
num_expert_group: int,
|
num_expert_group: Optional[int],
|
||||||
topk_group: int,
|
topk_group: Optional[int],
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
local_expert_offset: int,
|
local_expert_offset: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
routed_scaling_factor: float = 1.0,
|
use_routing_scales_on_input: bool,
|
||||||
use_routing_scales_on_input: bool = False,
|
routing_method_type: int,
|
||||||
tile_tokens_dim: int = 8,
|
routed_scaling_factor: float = 1.0) -> torch.Tensor:
|
||||||
routing_method_type: int = 0) -> torch.Tensor:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -36,7 +36,7 @@ from vllm.model_executor.utils import set_weight_attrs
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.interface import CpuArchEnum
|
from vllm.platforms.interface import CpuArchEnum
|
||||||
from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
|
from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
|
||||||
has_triton_kernels, is_torch_equal_or_newer, round_up)
|
round_up)
|
||||||
from vllm.utils.flashinfer import has_flashinfer
|
from vllm.utils.flashinfer import has_flashinfer
|
||||||
|
|
||||||
if current_platform.is_cuda_alike():
|
if current_platform.is_cuda_alike():
|
||||||
@ -751,19 +751,11 @@ class FusedMoE(CustomOp):
|
|||||||
self.global_num_experts = num_experts + num_redundant_experts
|
self.global_num_experts = num_experts + num_redundant_experts
|
||||||
|
|
||||||
# we padding globally so EP buffer allocation works
|
# we padding globally so EP buffer allocation works
|
||||||
if quant_config and quant_config.get_name() == "mxfp4":
|
if (quant_config and quant_config.get_name() == "mxfp4"
|
||||||
if not current_platform.is_device_capability(100):
|
and (current_platform.is_rocm()
|
||||||
if not is_torch_equal_or_newer("2.8.0"):
|
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||||
raise RuntimeError(
|
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16)):
|
||||||
"Mxfp4 on non-blackwell requires torch >= 2.8.0")
|
hidden_size = round_up(hidden_size, 256)
|
||||||
if not has_triton_kernels():
|
|
||||||
raise NotImplementedError(
|
|
||||||
"triton_kernels must be installed for "
|
|
||||||
"mxfp4 on non-blackwell")
|
|
||||||
if (current_platform.is_rocm()
|
|
||||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
|
||||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
|
||||||
hidden_size = round_up(hidden_size, 256)
|
|
||||||
|
|
||||||
# For smuggling this layer into the fused moe custom op
|
# For smuggling this layer into the fused moe custom op
|
||||||
compilation_config = vllm_config.compilation_config
|
compilation_config = vllm_config.compilation_config
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from vllm.attention.backends.placeholder_attn import (
|
|||||||
PlaceholderAttentionMetadata)
|
PlaceholderAttentionMetadata)
|
||||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.attention.backends.mamba_attn import (
|
from vllm.v1.attention.backends.mamba2_attn import (
|
||||||
Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
|
Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,14 +1,15 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from typing import Optional
|
from typing import NamedTuple, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
@ -19,7 +20,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||||
MambaStateShapeCalculator)
|
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||||
causal_conv1d_fn, causal_conv1d_update)
|
causal_conv1d_fn, causal_conv1d_update)
|
||||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||||
@ -55,6 +56,8 @@ class MambaMixer(MambaBase, CustomOp):
|
|||||||
rms_norm_eps: float = 1e-5,
|
rms_norm_eps: float = 1e-5,
|
||||||
activation="silu",
|
activation="silu",
|
||||||
is_lora_enabled: bool = False,
|
is_lora_enabled: bool = False,
|
||||||
|
model_config: Optional[ModelConfig] = None,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
prefix: str = ""):
|
prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.time_step_rank = time_step_rank
|
self.time_step_rank = time_step_rank
|
||||||
@ -152,15 +155,42 @@ class MambaMixer(MambaBase, CustomOp):
|
|||||||
# The inner tuple is (conv_state, ssm_state)
|
# The inner tuple is (conv_state, ssm_state)
|
||||||
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
|
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
|
||||||
|
|
||||||
|
self.model_config = model_config
|
||||||
|
self.cache_config = cache_config
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
|
|
||||||
|
def _ssm_transform(
|
||||||
|
self, x: torch.Tensor
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
if self.is_lora_enabled:
|
||||||
|
# Lora kernel requires contiguous tensor.
|
||||||
|
ssm_params = self.x_proj(x.contiguous())[0]
|
||||||
|
else:
|
||||||
|
ssm_params = self.x_proj(x)[0]
|
||||||
|
time_step, B, C = torch.split(
|
||||||
|
ssm_params,
|
||||||
|
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
|
||||||
|
dim=-1)
|
||||||
|
if self.use_rms_norm:
|
||||||
|
assert self.dt_layernorm is not None
|
||||||
|
assert self.b_layernorm is not None
|
||||||
|
assert self.c_layernorm is not None
|
||||||
|
time_step = self.dt_layernorm(time_step.contiguous())
|
||||||
|
B = self.b_layernorm(B.contiguous())
|
||||||
|
C = self.c_layernorm(C.contiguous())
|
||||||
|
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
|
||||||
|
return discrete_time_step, B, C
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
mamba_cache_params: Optional[MambaCacheParams] = None):
|
mamba_cache_params: Optional[MambaCacheParams] = None):
|
||||||
if not envs.VLLM_USE_V1:
|
if not envs.VLLM_USE_V1:
|
||||||
return CustomOp.forward(self, hidden_states, mamba_cache_params)
|
return CustomOp.forward(self, hidden_states, mamba_cache_params)
|
||||||
else:
|
else:
|
||||||
return self.forward_cuda(hidden_states, mamba_cache_params)
|
return self.forward_cuda(
|
||||||
|
hidden_states,
|
||||||
|
mamba_cache_params,
|
||||||
|
)
|
||||||
|
|
||||||
def forward_native(self,
|
def forward_native(self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@ -170,6 +200,27 @@ class MambaMixer(MambaBase, CustomOp):
|
|||||||
def forward_cuda(self,
|
def forward_cuda(self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
mamba_cache_params: Optional[MambaCacheParams] = None):
|
mamba_cache_params: Optional[MambaCacheParams] = None):
|
||||||
|
"""
|
||||||
|
Run the Mamba-1 SSM pipeline.
|
||||||
|
|
||||||
|
Steps
|
||||||
|
-----
|
||||||
|
1. Apply the gated-MLP linear projection to the raw input.
|
||||||
|
2. Pass the projected sequence through the convolutional mixing layer.
|
||||||
|
3. Feed the result into the State-Space Model (SSM) blocks.
|
||||||
|
4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
|
||||||
|
to produce contextual representations.
|
||||||
|
5. Project the contextualised sequence back
|
||||||
|
to the output embedding dimension.
|
||||||
|
|
||||||
|
Batch handling
|
||||||
|
--------------
|
||||||
|
Prefill and decode tokens are processed by dedicated CUDA
|
||||||
|
kernels for both the convolutional (conv1d) and SSM stages.
|
||||||
|
In the case of a mixed batch (containing both prefill and
|
||||||
|
decode tokens), both sets of kernels are executed independently
|
||||||
|
and their outputs are concatenated before the final output projection.
|
||||||
|
"""
|
||||||
|
|
||||||
forward_context: ForwardContext = get_forward_context()
|
forward_context: ForwardContext = get_forward_context()
|
||||||
attn_metadata = forward_context.attn_metadata
|
attn_metadata = forward_context.attn_metadata
|
||||||
@ -185,126 +236,151 @@ class MambaMixer(MambaBase, CustomOp):
|
|||||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||||
ssm_state = self_kv_cache[1]
|
ssm_state = self_kv_cache[1]
|
||||||
has_initial_state = mamba1_metadata.has_initial_states
|
has_initial_states = mamba1_metadata.has_initial_states
|
||||||
context_lens_tensor = mamba1_metadata.context_lens_tensor
|
|
||||||
else:
|
else:
|
||||||
|
assert isinstance(attn_metadata, AttentionMetadata)
|
||||||
assert mamba_cache_params is not None
|
assert mamba_cache_params is not None
|
||||||
conv_state = mamba_cache_params.conv_state
|
conv_state = mamba_cache_params.conv_state
|
||||||
ssm_state = mamba_cache_params.ssm_state
|
ssm_state = mamba_cache_params.ssm_state
|
||||||
state_indices_tensor = mamba_cache_params.state_indices_tensor
|
state_indices_tensor = mamba_cache_params.state_indices_tensor
|
||||||
query_start_loc = attn_metadata.query_start_loc
|
query_start_loc = attn_metadata.query_start_loc
|
||||||
context_lens_tensor = attn_metadata.context_lens_tensor
|
context_lens_tensor = attn_metadata.context_lens_tensor
|
||||||
|
has_initial_states = None
|
||||||
if context_lens_tensor is not None:
|
if context_lens_tensor is not None:
|
||||||
has_initial_state = context_lens_tensor > 0
|
has_initial_states = context_lens_tensor > 0
|
||||||
|
|
||||||
# 1. Gated MLP's linear projection
|
# 1. Gated MLP's linear projection
|
||||||
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||||
hidden_states, gate = projected_states.chunk(2, dim=-2)
|
hidden_states_BC, gate = projected_states.chunk(2, dim=-2)
|
||||||
|
|
||||||
# 2. Convolution sequence transformation
|
|
||||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||||
self.conv1d.weight.size(2))
|
self.conv1d.weight.size(2))
|
||||||
|
|
||||||
if envs.VLLM_USE_V1 and attn_metadata is None:
|
if envs.VLLM_USE_V1 and attn_metadata is None:
|
||||||
# V1 profile run
|
# V1 profile run
|
||||||
hidden_states = hidden_states.contiguous()
|
hidden_states_BC = hidden_states_BC.contiguous()
|
||||||
return self.out_proj(hidden_states.transpose(-2, -1))[0]
|
return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
|
||||||
|
|
||||||
if query_start_loc is not None and context_lens_tensor is not None:
|
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
|
||||||
# |---------- N-1 iteration --------|
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||||
# |---------------- N iteration ---------------------|
|
num_prefills = attn_metadata.num_prefills # request count
|
||||||
# |- tokenA -|......................|-- newTokens ---|
|
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
|
||||||
# |---------- context_len ----------|
|
has_prefill = num_prefill_tokens > 0
|
||||||
# |-------------------- seq_len ---------------------|
|
has_decode = num_decode_tokens > 0
|
||||||
# |-- query_len ---|
|
|
||||||
hidden_states = causal_conv1d_fn(
|
prefill_decode_split = split_batch_to_prefill_and_decode(
|
||||||
hidden_states,
|
hidden_states_BC,
|
||||||
|
gate,
|
||||||
|
state_indices_tensor,
|
||||||
|
query_start_loc,
|
||||||
|
has_initial_states,
|
||||||
|
num_prefill_tokens,
|
||||||
|
num_decode_tokens,
|
||||||
|
num_prefills,
|
||||||
|
num_decodes,
|
||||||
|
)
|
||||||
|
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
|
||||||
|
hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
|
||||||
|
gate_p = prefill_decode_split.gate_p
|
||||||
|
gate_d = prefill_decode_split.gate_d
|
||||||
|
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
|
||||||
|
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
|
||||||
|
query_start_loc_p = prefill_decode_split.query_start_loc_p
|
||||||
|
has_initial_states_p = prefill_decode_split.has_initial_states_p
|
||||||
|
|
||||||
|
ssm_outputs = []
|
||||||
|
|
||||||
|
if has_prefill:
|
||||||
|
# 2. Convolution sequence transformation
|
||||||
|
conv_out_p = causal_conv1d_fn(
|
||||||
|
hidden_states_BC_p,
|
||||||
conv_weights,
|
conv_weights,
|
||||||
bias=self.conv1d.bias,
|
self.conv1d.bias,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
conv_states=conv_state,
|
conv_states=conv_state,
|
||||||
has_initial_state=has_initial_state,
|
has_initial_state=has_initial_states_p,
|
||||||
cache_indices=state_indices_tensor,
|
cache_indices=state_indices_tensor_p,
|
||||||
query_start_loc=query_start_loc)
|
query_start_loc=query_start_loc_p)
|
||||||
else:
|
# 3. State Space Model sequence transformations.
|
||||||
hidden_states = causal_conv1d_update(
|
discrete_time_step_p, B_p, C_p = self._ssm_transform(
|
||||||
hidden_states.transpose(0, 1),
|
conv_out_p.transpose(-2, -1))
|
||||||
|
time_proj_bias = self._time_proj_bias()
|
||||||
|
|
||||||
|
# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
|
||||||
|
scan_out_p = selective_scan_fn(
|
||||||
|
conv_out_p,
|
||||||
|
ssm_state,
|
||||||
|
discrete_time_step_p,
|
||||||
|
self.A,
|
||||||
|
B_p.transpose(-2, -1),
|
||||||
|
C_p.transpose(-2, -1),
|
||||||
|
self.D.float(),
|
||||||
|
gate_p,
|
||||||
|
time_proj_bias,
|
||||||
|
delta_softplus=True,
|
||||||
|
cache_indices=state_indices_tensor_p,
|
||||||
|
has_initial_state=has_initial_states_p,
|
||||||
|
query_start_loc=query_start_loc_p)
|
||||||
|
ssm_outputs.append(scan_out_p)
|
||||||
|
|
||||||
|
if has_decode:
|
||||||
|
# 2. Convolution sequence transformation
|
||||||
|
conv_out_d = causal_conv1d_update(
|
||||||
|
hidden_states_BC_d.transpose(0, 1),
|
||||||
conv_state,
|
conv_state,
|
||||||
conv_weights,
|
conv_weights,
|
||||||
self.conv1d.bias,
|
self.conv1d.bias,
|
||||||
self.activation,
|
self.activation,
|
||||||
conv_state_indices=state_indices_tensor)
|
conv_state_indices=state_indices_tensor_d).transpose(0, 1)
|
||||||
hidden_states = hidden_states.transpose(0, 1)
|
|
||||||
|
|
||||||
# 3. State Space Model sequence transformation
|
# 3. State Space Model sequence transformation.
|
||||||
# 3.a. input varying initialization of time_step, B and C
|
discrete_time_step_d, B_d, C_d = self._ssm_transform(
|
||||||
|
conv_out_d.transpose(-2, -1))
|
||||||
|
time_proj_bias = self._time_proj_bias()
|
||||||
|
|
||||||
if self.is_lora_enabled:
|
# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
|
||||||
# lora kernel requires contiguous tensor
|
scan_outputs_d = torch.empty_like(
|
||||||
ssm_parameters = self.x_proj(
|
hidden_states_BC_d.transpose(0, 1))
|
||||||
hidden_states.transpose(-2, -1).contiguous())[0]
|
|
||||||
else:
|
|
||||||
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
|
|
||||||
|
|
||||||
time_step, B, C = torch.split(
|
|
||||||
ssm_parameters,
|
|
||||||
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
if self.use_rms_norm:
|
|
||||||
assert self.dt_layernorm is not None
|
|
||||||
assert self.b_layernorm is not None
|
|
||||||
assert self.c_layernorm is not None
|
|
||||||
time_step = self.dt_layernorm(time_step.contiguous())
|
|
||||||
B = self.b_layernorm(B.contiguous())
|
|
||||||
C = self.c_layernorm(C.contiguous())
|
|
||||||
|
|
||||||
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
|
|
||||||
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
|
|
||||||
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
|
|
||||||
self.dt_proj, "bias") else None)
|
|
||||||
|
|
||||||
if query_start_loc is not None and context_lens_tensor is not None:
|
|
||||||
scan_outputs = selective_scan_fn(
|
|
||||||
hidden_states,
|
|
||||||
ssm_state,
|
|
||||||
discrete_time_step,
|
|
||||||
self.A,
|
|
||||||
B.transpose(-2, -1),
|
|
||||||
C.transpose(-2, -1),
|
|
||||||
self.D.float(),
|
|
||||||
gate,
|
|
||||||
time_proj_bias,
|
|
||||||
delta_softplus=True,
|
|
||||||
cache_indices=state_indices_tensor,
|
|
||||||
has_initial_state=has_initial_state,
|
|
||||||
query_start_loc=query_start_loc)
|
|
||||||
else:
|
|
||||||
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
|
|
||||||
selective_state_update(ssm_state,
|
selective_state_update(ssm_state,
|
||||||
hidden_states.transpose(0, 1),
|
conv_out_d.transpose(0, 1),
|
||||||
discrete_time_step.transpose(0, 1),
|
discrete_time_step_d.transpose(0, 1),
|
||||||
self.A,
|
self.A,
|
||||||
B,
|
B_d,
|
||||||
C,
|
C_d,
|
||||||
self.D,
|
self.D,
|
||||||
gate.transpose(0, 1),
|
gate_d.transpose(0, 1),
|
||||||
time_proj_bias,
|
time_proj_bias,
|
||||||
dt_softplus=True,
|
dt_softplus=True,
|
||||||
state_batch_indices=state_indices_tensor,
|
state_batch_indices=state_indices_tensor_d,
|
||||||
out=scan_outputs)
|
out=scan_outputs_d)
|
||||||
scan_outputs = scan_outputs.transpose(0, 1)
|
scan_outputs_d = scan_outputs_d.transpose(0, 1)
|
||||||
|
|
||||||
# 4. Final linear projection
|
if envs.VLLM_USE_V1:
|
||||||
if self.is_lora_enabled:
|
ssm_outputs.insert(0, scan_outputs_d)
|
||||||
# lora kernel requires contiguous tensor
|
else:
|
||||||
contextualized_states = self.out_proj(
|
ssm_outputs.append(scan_outputs_d)
|
||||||
scan_outputs.transpose(-2, -1).contiguous())[0]
|
|
||||||
|
scan_outputs_combined = ssm_outputs[0] if len(
|
||||||
|
ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
|
||||||
|
|
||||||
|
# 5. Final output projection
|
||||||
|
if self.is_lora_enabled: # Lora kernel requires contiguous tensor.
|
||||||
|
scan_outputs_combined = scan_outputs_combined.transpose(
|
||||||
|
-2, -1).contiguous()
|
||||||
|
out = self.out_proj(scan_outputs_combined)[0]
|
||||||
else:
|
else:
|
||||||
contextualized_states = self.out_proj(
|
out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
|
||||||
scan_outputs.transpose(-2, -1))[0]
|
|
||||||
return contextualized_states
|
return out
|
||||||
|
|
||||||
|
def get_state_dtype(self) -> tuple[torch.dtype]:
|
||||||
|
assert self.model_config is not None
|
||||||
|
assert self.cache_config is not None
|
||||||
|
return MambaStateDtypeCalculator.mamba1_state_dtype(
|
||||||
|
self.model_config.dtype,
|
||||||
|
self.cache_config.mamba_cache_dtype,
|
||||||
|
self.cache_config.mamba_ssm_cache_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||||
return MambaStateShapeCalculator.mamba1_state_shape(
|
return MambaStateShapeCalculator.mamba1_state_shape(
|
||||||
@ -317,3 +393,69 @@ class MambaMixer(MambaBase, CustomOp):
|
|||||||
@property
|
@property
|
||||||
def mamba_type(self) -> str:
|
def mamba_type(self) -> str:
|
||||||
return "mamba1"
|
return "mamba1"
|
||||||
|
|
||||||
|
def _time_proj_bias(self) -> Optional[torch.Tensor]:
|
||||||
|
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
|
||||||
|
return self.dt_proj.bias.float()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class PrefillDecodeSplit(NamedTuple):
|
||||||
|
hidden_states_BC_p: torch.Tensor
|
||||||
|
hidden_states_BC_d: torch.Tensor
|
||||||
|
gate_p: torch.Tensor
|
||||||
|
gate_d: torch.Tensor
|
||||||
|
state_indices_tensor_p: torch.Tensor
|
||||||
|
state_indices_tensor_d: torch.Tensor
|
||||||
|
query_start_loc_p: Optional[torch.Tensor]
|
||||||
|
has_initial_states_p: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
def split_batch_to_prefill_and_decode(
|
||||||
|
hidden_states_BC: torch.Tensor,
|
||||||
|
gate: torch.Tensor,
|
||||||
|
state_indices_tensor: torch.Tensor,
|
||||||
|
query_start_loc: torch.Tensor,
|
||||||
|
has_initial_states: Optional[torch.Tensor],
|
||||||
|
num_prefill_tokens: int,
|
||||||
|
num_decode_tokens: int,
|
||||||
|
num_prefills: int,
|
||||||
|
num_decodes: int,
|
||||||
|
) -> PrefillDecodeSplit:
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
# In v1, decode tokens come first, then prefill tokens.
|
||||||
|
hidden_states_BC_d, hidden_states_BC_p = torch.split(
|
||||||
|
hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
|
||||||
|
gate_d, gate_p = torch.split(gate,
|
||||||
|
[num_decode_tokens, num_prefill_tokens],
|
||||||
|
dim=-1)
|
||||||
|
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||||
|
state_indices_tensor, [num_decodes, num_prefills], dim=0)
|
||||||
|
query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
|
||||||
|
num_decodes if num_prefills > 0 else None)
|
||||||
|
has_initial_states_p = has_initial_states[-num_prefills:] if (
|
||||||
|
has_initial_states is not None and num_prefills > 0) else None
|
||||||
|
else:
|
||||||
|
# In v0, prefill tokens come first, then decode tokens.
|
||||||
|
hidden_states_BC_p, hidden_states_BC_d = torch.split(
|
||||||
|
hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)
|
||||||
|
gate_p, gate_d = torch.split(gate,
|
||||||
|
[num_prefill_tokens, num_decode_tokens],
|
||||||
|
dim=-1)
|
||||||
|
state_indices_tensor_p, state_indices_tensor_d = torch.split(
|
||||||
|
state_indices_tensor, [num_prefills, num_decodes], dim=0)
|
||||||
|
query_start_loc_p = (query_start_loc[:num_prefills +
|
||||||
|
1] if num_prefills > 0 else None)
|
||||||
|
has_initial_states_p = has_initial_states[:num_prefills] if (
|
||||||
|
has_initial_states is not None and num_prefills > 0) else None
|
||||||
|
|
||||||
|
return PrefillDecodeSplit(
|
||||||
|
hidden_states_BC_p=hidden_states_BC_p,
|
||||||
|
hidden_states_BC_d=hidden_states_BC_d,
|
||||||
|
gate_p=gate_p,
|
||||||
|
gate_d=gate_d,
|
||||||
|
state_indices_tensor_p=state_indices_tensor_p,
|
||||||
|
state_indices_tensor_d=state_indices_tensor_d,
|
||||||
|
query_start_loc_p=query_start_loc_p,
|
||||||
|
has_initial_states_p=has_initial_states_p,
|
||||||
|
)
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from torch import nn
|
|||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
@ -21,7 +21,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
|
|||||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
|
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
|
||||||
update_metadata)
|
update_metadata)
|
||||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||||
MambaStateShapeCalculator)
|
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||||
causal_conv1d_fn, causal_conv1d_update)
|
causal_conv1d_fn, causal_conv1d_update)
|
||||||
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
|
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
|
||||||
@ -36,7 +36,7 @@ from vllm.model_executor.models.mamba_cache import MambaCacheParams
|
|||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata
|
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
|
||||||
|
|
||||||
# Added by the IBM Team, 2024
|
# Added by the IBM Team, 2024
|
||||||
|
|
||||||
@ -218,23 +218,23 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
**selective** state spaces)
|
**selective** state spaces)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self,
|
||||||
self,
|
hidden_size: int,
|
||||||
hidden_size: int,
|
ssm_state_size: int,
|
||||||
ssm_state_size: int,
|
conv_kernel_size: int,
|
||||||
conv_kernel_size: int,
|
intermediate_size: int,
|
||||||
intermediate_size: int,
|
use_conv_bias: bool,
|
||||||
use_conv_bias: bool,
|
use_bias: bool,
|
||||||
use_bias: bool,
|
n_groups: int = 1,
|
||||||
n_groups: int = 1,
|
num_heads: int = 128,
|
||||||
num_heads: int = 128,
|
head_dim: int = 64,
|
||||||
head_dim: int = 64,
|
rms_norm_eps: float = 1e-5,
|
||||||
rms_norm_eps: float = 1e-5,
|
activation: str = "silu",
|
||||||
activation: str = "silu",
|
use_rms_norm: bool = True,
|
||||||
use_rms_norm: bool = True,
|
model_config: Optional[ModelConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
prefix: str = "",
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# For TP, the sharding plan is as follows:
|
# For TP, the sharding plan is as follows:
|
||||||
@ -417,6 +417,8 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
# The inner tuple is (conv_state, ssm_state)
|
# The inner tuple is (conv_state, ssm_state)
|
||||||
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
|
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
|
||||||
|
|
||||||
|
self.model_config = model_config
|
||||||
|
self.cache_config = cache_config
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
@ -670,7 +672,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
dt_limit=(0.0, float("inf")),
|
dt_limit=(0.0, float("inf")),
|
||||||
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
|
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
|
||||||
self.head_dim),
|
self.head_dim),
|
||||||
)
|
state_dtype=ssm_state.dtype)
|
||||||
|
|
||||||
# update ssm states
|
# update ssm states
|
||||||
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
|
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
|
||||||
@ -732,6 +734,15 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
# 5. Final linear projection
|
# 5. Final linear projection
|
||||||
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
|
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
|
||||||
|
|
||||||
|
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
|
||||||
|
assert self.model_config is not None
|
||||||
|
assert self.cache_config is not None
|
||||||
|
return MambaStateDtypeCalculator.mamba2_state_dtype(
|
||||||
|
self.model_config.dtype,
|
||||||
|
self.cache_config.mamba_cache_dtype,
|
||||||
|
self.cache_config.mamba_ssm_cache_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||||
return MambaStateShapeCalculator.mamba2_state_shape(
|
return MambaStateShapeCalculator.mamba2_state_shape(
|
||||||
intermediate_size=self.intermediate_size,
|
intermediate_size=self.intermediate_size,
|
||||||
|
|||||||
@ -1,6 +1,58 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.config import MambaDType, ModelDType
|
||||||
from vllm.distributed import divide
|
from vllm.distributed import divide
|
||||||
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype
|
||||||
|
|
||||||
|
|
||||||
|
class MambaStateDtypeCalculator:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def linear_attention_state_dtype(
|
||||||
|
cls,
|
||||||
|
model_dtype: Union[ModelDType, torch.dtype],
|
||||||
|
mamba_cache_dtype: MambaDType,
|
||||||
|
) -> tuple[torch.dtype, ...]:
|
||||||
|
# TODO (tdoublep) requires testing
|
||||||
|
if mamba_cache_dtype == "float32":
|
||||||
|
raise ValueError("fp32 state for minimax is not yet supported")
|
||||||
|
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||||
|
return (state_dtype, )
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def mamba1_state_dtype(
|
||||||
|
cls,
|
||||||
|
model_dtype: Union[ModelDType, torch.dtype],
|
||||||
|
mamba_cache_dtype: MambaDType,
|
||||||
|
mamba_ssm_cache_dtype: MambaDType,
|
||||||
|
) -> tuple[torch.dtype, ...]:
|
||||||
|
# TODO (tdoublep) requires kernel changes
|
||||||
|
if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32":
|
||||||
|
raise ValueError("fp32 state for mamba1 is not yet supported")
|
||||||
|
else:
|
||||||
|
return MambaStateDtypeCalculator.mamba2_state_dtype(
|
||||||
|
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def mamba2_state_dtype(
|
||||||
|
cls,
|
||||||
|
model_dtype: Union[ModelDType, torch.dtype],
|
||||||
|
mamba_cache_dtype: MambaDType,
|
||||||
|
mamba_ssm_cache_dtype: MambaDType,
|
||||||
|
) -> tuple[torch.dtype, ...]:
|
||||||
|
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
|
||||||
|
model_dtype)
|
||||||
|
if mamba_ssm_cache_dtype == "auto":
|
||||||
|
temporal_state_dtype = conv_state_dtype
|
||||||
|
else:
|
||||||
|
temporal_state_dtype = (
|
||||||
|
STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype])
|
||||||
|
|
||||||
|
return (conv_state_dtype, temporal_state_dtype)
|
||||||
|
|
||||||
|
|
||||||
class MambaStateShapeCalculator:
|
class MambaStateShapeCalculator:
|
||||||
|
|||||||
@ -41,6 +41,7 @@ def _mamba_chunk_scan_combined_fwd(x,
|
|||||||
cu_seqlens=None,
|
cu_seqlens=None,
|
||||||
dt_softplus=False,
|
dt_softplus=False,
|
||||||
dt_limit=(0.0, float("inf")),
|
dt_limit=(0.0, float("inf")),
|
||||||
|
state_dtype=None,
|
||||||
out=None):
|
out=None):
|
||||||
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
|
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
|
||||||
batch, seqlen, nheads, headdim = x.shape
|
batch, seqlen, nheads, headdim = x.shape
|
||||||
@ -118,7 +119,7 @@ def _mamba_chunk_scan_combined_fwd(x,
|
|||||||
if initial_states is not None else None,
|
if initial_states is not None else None,
|
||||||
seq_idx=seq_idx,
|
seq_idx=seq_idx,
|
||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
out_dtype=C.dtype,
|
out_dtype=state_dtype if state_dtype is not None else C.dtype,
|
||||||
is_cont_batched=cu_seqlens is not None)
|
is_cont_batched=cu_seqlens is not None)
|
||||||
states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
|
states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
|
||||||
for t in [states, final_states])
|
for t in [states, final_states])
|
||||||
@ -189,7 +190,8 @@ def mamba_chunk_scan_combined(x,
|
|||||||
dt_limit=(0.0, float("inf")),
|
dt_limit=(0.0, float("inf")),
|
||||||
out=None,
|
out=None,
|
||||||
return_final_states=False,
|
return_final_states=False,
|
||||||
return_varlen_states=False):
|
return_varlen_states=False,
|
||||||
|
state_dtype=None):
|
||||||
"""
|
"""
|
||||||
Argument:
|
Argument:
|
||||||
x: (batch, seqlen, nheads, headdim)
|
x: (batch, seqlen, nheads, headdim)
|
||||||
@ -206,6 +208,7 @@ def mamba_chunk_scan_combined(x,
|
|||||||
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
|
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
|
||||||
dt_softplus: Whether to apply softplus to dt
|
dt_softplus: Whether to apply softplus to dt
|
||||||
out: Preallocated output tensor
|
out: Preallocated output tensor
|
||||||
|
state_dtype: The data type of the ssm state
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not return_varlen_states:
|
if not return_varlen_states:
|
||||||
@ -229,7 +232,8 @@ def mamba_chunk_scan_combined(x,
|
|||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
dt_softplus=dt_softplus,
|
dt_softplus=dt_softplus,
|
||||||
dt_limit=dt_limit,
|
dt_limit=dt_limit,
|
||||||
out=out)
|
out=out,
|
||||||
|
state_dtype=state_dtype)
|
||||||
if not return_varlen_states:
|
if not return_varlen_states:
|
||||||
if not return_final_states:
|
if not return_final_states:
|
||||||
return
|
return
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user