mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 11:17:11 +08:00
wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
c2d1b075ba
commit
1aaced5830
@ -193,6 +193,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/activation_kernels.cu"
|
||||
"csrc/layernorm_kernels.cu"
|
||||
"csrc/layernorm_quant_kernels.cu"
|
||||
"csrc/cuda_view.cu"
|
||||
"csrc/quantization/gptq/q_gemm.cu"
|
||||
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
|
||||
"csrc/quantization/fp8/common.cu"
|
||||
@ -200,6 +201,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/prepare_inputs/advance_step.cu"
|
||||
"csrc/prepare_inputs/copy_subranges.cu"
|
||||
"csrc/torch_bindings.cpp")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
@ -47,3 +47,11 @@
|
||||
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
||||
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
||||
#endif
|
||||
|
||||
// #ifndef USE_ROCM
|
||||
// #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
|
||||
// cudaHostGetDevicePointer(device_ptr, host_ptr, flags)
|
||||
// #else
|
||||
// #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
|
||||
// hipHostGetDevicePointer(device_ptr, host_ptr, flags)
|
||||
// #endif
|
||||
|
||||
43
csrc/cuda_view.cu
Normal file
43
csrc/cuda_view.cu
Normal file
@ -0,0 +1,43 @@
|
||||
#include <torch/all.h>
|
||||
#include <torch/cuda.h>
|
||||
|
||||
// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
|
||||
// memory, and that UVA (Unified Virtual Addressing) is enabled.
|
||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
|
||||
TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
|
||||
TORCH_CHECK(cpu_tensor.is_contiguous(), "Input tensor must be contiguous");
|
||||
|
||||
// Get raw host pointer from CPU tensor
|
||||
void* host_ptr = cpu_tensor.data_ptr();
|
||||
|
||||
// Get a device pointer corresponding to the pinned host memory
|
||||
void* device_ptr = nullptr;
|
||||
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
|
||||
TORCH_CHECK(err == cudaSuccess,
|
||||
"cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
|
||||
|
||||
// Construct a CUDA tensor from the device pointer.
|
||||
// We'll use the same sizes, strides, and dtype as the CPU tensor.
|
||||
auto sizes = cpu_tensor.sizes();
|
||||
auto strides = cpu_tensor.strides();
|
||||
auto options =
|
||||
cpu_tensor.options().device(torch::kCUDA); // Change device to CUDA
|
||||
|
||||
// from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
|
||||
// const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
|
||||
// memory, so we don't free it here.
|
||||
auto deleter = [](void*) {
|
||||
// no-op, since the memory is owned by the original CPU tensor
|
||||
};
|
||||
|
||||
torch::Tensor cuda_tensor =
|
||||
torch::from_blob(device_ptr, sizes, strides, deleter, options);
|
||||
|
||||
TORCH_CHECK(cuda_tensor.device().is_cuda(),
|
||||
"Resulting tensor is not on CUDA device");
|
||||
TORCH_CHECK(cuda_tensor.sizes().equals(sizes), "Size mismatch");
|
||||
TORCH_CHECK(cuda_tensor.strides().equals(strides), "Stride mismatch");
|
||||
TORCH_CHECK(cuda_tensor.dtype() == cpu_tensor.dtype(), "Dtype mismatch");
|
||||
|
||||
return cuda_tensor;
|
||||
}
|
||||
@ -115,6 +115,11 @@ void advance_step_flashinfer(
|
||||
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
||||
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
|
||||
|
||||
void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
|
||||
torch::Tensor& matrix_tgt, int64_t n);
|
||||
|
||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
|
||||
63
csrc/prepare_inputs/copy_subranges.cu
Normal file
63
csrc/prepare_inputs/copy_subranges.cu
Normal file
@ -0,0 +1,63 @@
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
namespace vllm {
|
||||
__global__ void copy_subranges_kernel(const int* __restrict__ matrix_src,
|
||||
const int* __restrict__ matrix_diff,
|
||||
int* __restrict__ matrix_tgt, int64_t M) {
|
||||
int row_id = blockIdx.x;
|
||||
int row_offset = row_id * M;
|
||||
|
||||
int start = matrix_diff[row_id * 2];
|
||||
int length = matrix_diff[row_id * 2 + 1];
|
||||
int end = start + length;
|
||||
int thread_idx = threadIdx.x;
|
||||
for (int i = start + thread_idx; i < end; i += blockDim.x) {
|
||||
int idx = row_offset + i;
|
||||
matrix_tgt[idx] = matrix_src[idx];
|
||||
}
|
||||
}
|
||||
} // namespace vllm
|
||||
|
||||
void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
|
||||
torch::Tensor& matrix_tgt, int64_t n) {
|
||||
// Check tensor properties
|
||||
TORCH_CHECK(matrix_src.is_cuda(), "matrix_src must be a CUDA tensor");
|
||||
TORCH_CHECK(matrix_diff.is_cuda(), "matrix_diff must be a CUDA tensor");
|
||||
TORCH_CHECK(matrix_tgt.is_cuda(), "matrix_tgt must be a CUDA tensor");
|
||||
TORCH_CHECK(matrix_src.is_contiguous(), "matrix_src must be contiguous");
|
||||
TORCH_CHECK(matrix_diff.is_contiguous(), "matrix_diff must be contiguous");
|
||||
TORCH_CHECK(matrix_tgt.is_contiguous(), "matrix_tgt must be contiguous");
|
||||
|
||||
auto src_sizes = matrix_src.sizes();
|
||||
auto diff_sizes = matrix_diff.sizes();
|
||||
auto tgt_sizes = matrix_tgt.sizes();
|
||||
|
||||
TORCH_CHECK(src_sizes.size() == 2, "matrix_src must be 2D");
|
||||
TORCH_CHECK(diff_sizes.size() == 2, "matrix_diff must be 2D");
|
||||
TORCH_CHECK(tgt_sizes.size() == 2, "matrix_tgt must be 2D");
|
||||
|
||||
int64_t N = src_sizes[0];
|
||||
int64_t M = src_sizes[1];
|
||||
|
||||
TORCH_CHECK(diff_sizes[0] == N, "matrix_diff first dim must match N");
|
||||
TORCH_CHECK(diff_sizes[1] == 2, "matrix_diff second dim must be 2");
|
||||
TORCH_CHECK(tgt_sizes[0] == N && tgt_sizes[1] == M,
|
||||
"matrix_tgt must have same shape as matrix_src");
|
||||
|
||||
TORCH_CHECK(n <= N, "n must be <= N");
|
||||
|
||||
const int* d_matrix_src = matrix_src.data_ptr<int>();
|
||||
const int* d_matrix_diff = matrix_diff.data_ptr<int>();
|
||||
int* d_matrix_tgt = matrix_tgt.data_ptr<int>();
|
||||
|
||||
// One thread block per row.
|
||||
int blocks = n;
|
||||
int threads = 1024;
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(matrix_tgt));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
vllm::copy_subranges_kernel<<<blocks, threads, 0, stream>>>(
|
||||
d_matrix_src, d_matrix_diff, d_matrix_tgt, M);
|
||||
}
|
||||
@ -21,6 +21,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
||||
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
|
||||
|
||||
ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
|
||||
ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
|
||||
&get_cuda_view_from_cpu_tensor);
|
||||
|
||||
// Attention ops
|
||||
// Compute the attention between an input query and the cached
|
||||
// keys/values using PagedAttention.
|
||||
@ -98,6 +102,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
") -> ()");
|
||||
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
|
||||
|
||||
ops.def(
|
||||
"copy_subranges(Tensor matrix_src, Tensor matrix_diff, Tensor! "
|
||||
"matrix_tgt, "
|
||||
"int n) -> ()");
|
||||
ops.impl("copy_subranges", torch::kCUDA, ©_subranges);
|
||||
|
||||
// Layernorm
|
||||
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
||||
ops.def(
|
||||
|
||||
@ -249,6 +249,17 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
|
||||
block_table_bound)
|
||||
|
||||
|
||||
# copy subrange op. Used for input preparation in the vLLM V1 GPU backend.
|
||||
def copy_subranges(
|
||||
src_matrix: torch.Tensor,
|
||||
diff_matrix: torch.Tensor,
|
||||
tgt_matrix: torch.Tensor,
|
||||
num_subranges: int,
|
||||
) -> None:
|
||||
torch.ops._C.copy_subranges(src_matrix, diff_matrix, tgt_matrix,
|
||||
num_subranges)
|
||||
|
||||
|
||||
# fused quant layer norm ops
|
||||
def rms_norm_dynamic_per_token_quant(
|
||||
input: torch.Tensor,
|
||||
|
||||
@ -1523,6 +1523,13 @@ def weak_ref_tensors(
|
||||
raise ValueError("Invalid type for tensors")
|
||||
|
||||
|
||||
def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
|
||||
"""
|
||||
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
|
||||
|
||||
|
||||
def is_in_doc_build() -> bool:
|
||||
try:
|
||||
from sphinx.ext.autodoc.mock import _MockModule
|
||||
|
||||
116
vllm/v1/worker/gpu_block_table.py
Normal file
116
vllm/v1/worker/gpu_block_table.py
Normal file
@ -0,0 +1,116 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import get_cuda_view_from_cpu_tensor
|
||||
|
||||
|
||||
class BlockTable:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_blocks_per_req: int,
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
self.max_num_blocks_per_req = max_num_blocks_per_req
|
||||
self.pin_memory = pin_memory
|
||||
self.device = device
|
||||
|
||||
self.block_table = torch.zeros(
|
||||
(max_num_reqs, max_num_blocks_per_req),
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
self.block_table_cpu = torch.zeros(
|
||||
(max_num_reqs, max_num_blocks_per_req),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.block_table_np = self.block_table_cpu.numpy()
|
||||
|
||||
# Pinned memory is required to use UVA.
|
||||
# TODO(woosuk): Add other requirements for UVA.
|
||||
self.use_uva = pin_memory
|
||||
if self.use_uva:
|
||||
self.block_table_diff = torch.zeros((max_num_reqs, 2),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
self.block_table_diff_np = self.block_table_diff.numpy()
|
||||
|
||||
self.block_table_cpu_cuda_view = get_cuda_view_from_cpu_tensor(
|
||||
self.block_table_cpu)
|
||||
self.block_table_diff_cuda_view = get_cuda_view_from_cpu_tensor(
|
||||
self.block_table_diff)
|
||||
|
||||
def add_row(self, row_idx: int, block_ids: List[int]) -> None:
|
||||
num_blocks = len(block_ids)
|
||||
self.block_table_np[row_idx, :num_blocks] = block_ids
|
||||
if self.use_uva:
|
||||
self.block_table_diff_np[row_idx, 0] = 0
|
||||
self.block_table_diff_np[row_idx, 1] = num_blocks
|
||||
|
||||
def append_row(
|
||||
self,
|
||||
row_idx: int,
|
||||
start: int,
|
||||
block_ids: List[int],
|
||||
) -> None:
|
||||
num_blocks = len(block_ids)
|
||||
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
|
||||
if self.use_uva:
|
||||
self.block_table_diff_np[row_idx, 0] = start
|
||||
# Move-and-append is not allowed.
|
||||
assert self.block_table_diff_np[row_idx, 1] == 0
|
||||
self.block_table_diff_np[row_idx, 1] = num_blocks
|
||||
|
||||
def move_row(self, src: int, tgt: int) -> None:
|
||||
self.block_table_np[tgt] = self.block_table_np[src]
|
||||
if self.use_uva:
|
||||
# Append-and-move is allowed.
|
||||
self.block_table_diff_np[tgt] = self.block_table_diff_np[src]
|
||||
# Clear the source row.
|
||||
self.block_table_diff_np[src].fill(0)
|
||||
|
||||
def apply_diff(self, num_reqs: int) -> None:
|
||||
if self.use_uva:
|
||||
# Only copy the diff to the GPU.
|
||||
ops.copy_subranges(
|
||||
self.block_table_cpu_cuda_view,
|
||||
self.block_table_diff_cuda_view,
|
||||
self.block_table,
|
||||
num_reqs,
|
||||
)
|
||||
else:
|
||||
# Copy the entire block table to the GPU.
|
||||
# NOTE(woosuk): This can be a performance bottleneck when the block
|
||||
# table is large.
|
||||
self.block_table[:num_reqs].copy_(
|
||||
self.block_table_cpu[:num_reqs], non_blocking=True)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.block_table.fill_(0)
|
||||
self.block_table_cpu.fill_(0)
|
||||
if self.use_uva:
|
||||
self.block_table_diff.fill_(0)
|
||||
|
||||
def clear_diff(self) -> None:
|
||||
if self.use_uva:
|
||||
self.block_table_diff_np.fill(0)
|
||||
|
||||
def cuda(self) -> torch.Tensor:
|
||||
return self.block_table
|
||||
|
||||
def cpu(self) -> torch.Tensor:
|
||||
return self.block_table_cpu
|
||||
|
||||
def numpy(self) -> np.ndarray:
|
||||
return self.block_table_np
|
||||
@ -9,6 +9,7 @@ import torch
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu_block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
@ -64,19 +65,14 @@ class InputBatch:
|
||||
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
||||
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
|
||||
|
||||
# Attention-related.
|
||||
self.block_table = torch.zeros(
|
||||
(max_num_reqs, max_num_blocks_per_req),
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
self.block_table_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs, max_num_blocks_per_req),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
# Block table.
|
||||
self.block_table = BlockTable(
|
||||
max_num_reqs=max_num_reqs,
|
||||
max_model_len=max_model_len,
|
||||
max_num_blocks_per_req=max_num_blocks_per_req,
|
||||
pin_memory=pin_memory,
|
||||
device=device,
|
||||
)
|
||||
self.block_table_cpu = self.block_table_cpu_tensor.numpy()
|
||||
|
||||
# Sampling-related.
|
||||
self.temperature = torch.empty((max_num_reqs, ),
|
||||
@ -141,8 +137,7 @@ class InputBatch:
|
||||
start_idx:end_idx] = request.output_token_ids
|
||||
|
||||
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
||||
num_blocks = len(request.block_ids)
|
||||
self.block_table_cpu[req_index, :num_blocks] = request.block_ids
|
||||
self.block_table.add_row(req_index, request.block_ids)
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
self.temperature_cpu[req_index] = sampling_params.temperature
|
||||
@ -221,13 +216,12 @@ class InputBatch:
|
||||
self.req_id_to_index[req_id] = empty_index
|
||||
|
||||
# TODO(woosuk): Optimize the copy of token_ids_cpu and
|
||||
# block_table_cpu.
|
||||
# block_table.
|
||||
self.token_ids_cpu[empty_index] = self.token_ids_cpu[
|
||||
last_req_index]
|
||||
self.num_computed_tokens_cpu[
|
||||
empty_index] = self.num_computed_tokens_cpu[last_req_index]
|
||||
self.block_table_cpu[empty_index] = self.block_table_cpu[
|
||||
last_req_index]
|
||||
self.block_table.move_row(last_req_index, empty_index)
|
||||
self.temperature_cpu[empty_index] = self.temperature_cpu[
|
||||
last_req_index]
|
||||
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
|
||||
|
||||
@ -162,6 +162,9 @@ class GPUModelRunner:
|
||||
self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
|
||||
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
# Clean up diffs.
|
||||
self.input_batch.block_table.clear_diff()
|
||||
|
||||
# Remove stopped requests from the cached states.
|
||||
# Keep the states of the pre-empted requests.
|
||||
for req_id in scheduler_output.finished_req_ids:
|
||||
@ -203,10 +206,9 @@ class GPUModelRunner:
|
||||
if num_new_blocks == 0:
|
||||
continue
|
||||
start_index = len(req_state.block_ids)
|
||||
end_index = start_index + num_new_blocks
|
||||
req_state.block_ids.extend(req_data.new_block_ids)
|
||||
self.input_batch.block_table_cpu[
|
||||
req_index, start_index:end_index] = req_data.new_block_ids
|
||||
self.input_batch.block_table.append_row(req_index, start_index,
|
||||
req_data.new_block_ids)
|
||||
|
||||
req_ids_to_add: List[str] = []
|
||||
# Add new requests to the cached states.
|
||||
@ -267,9 +269,7 @@ class GPUModelRunner:
|
||||
|
||||
# OPTIMIZATION: Start copying the block table first.
|
||||
# This way, we can overlap the copy with the following CPU operations.
|
||||
self.input_batch.block_table[:num_reqs].copy_(
|
||||
self.input_batch.block_table_cpu_tensor[:num_reqs],
|
||||
non_blocking=True)
|
||||
self.input_batch.block_table.apply_diff(num_reqs)
|
||||
|
||||
# Get the number of scheduled tokens for each request.
|
||||
# TODO: The Python loop can be slow. Optimize.
|
||||
@ -325,7 +325,7 @@ class GPUModelRunner:
|
||||
# NOTE(woosuk): We use torch.index_select instead of np.take here
|
||||
# because torch.index_select is much faster than np.take for large
|
||||
# tensors.
|
||||
block_numbers = (self.input_batch.block_table_cpu_tensor.flatten()
|
||||
block_numbers = (self.input_batch.block_table.cpu().flatten()
|
||||
[block_table_indices].numpy())
|
||||
block_offsets = positions_np % self.block_size
|
||||
np.add(block_numbers * self.block_size,
|
||||
@ -360,7 +360,7 @@ class GPUModelRunner:
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_start_loc=seq_start_loc,
|
||||
block_table=self.input_batch.block_table[:num_reqs],
|
||||
block_table=self.input_batch.block_table.cuda()[:num_reqs],
|
||||
slot_mapping=slot_mapping,
|
||||
)
|
||||
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user