diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index d2324d7cee60f..c4fc43dc0abb8 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -230,14 +230,12 @@ steps:
   commands:
   - pytest -v -s compile/test_basic_correctness.py
 
-# TODO: re-write in comparison tests, and fix symbolic shape
-# for quantization ops.
-# - label: "PyTorch Fullgraph Test" # 18min
-#   source_file_dependencies:
-#   - vllm/
-#   - tests/compile
-#   commands:
-#   - pytest -v -s compile/test_full_graph.py
+- label: "PyTorch Fullgraph Test" # 18min
+  source_file_dependencies:
+  - vllm/
+  - tests/compile
+  commands:
+  - pytest -v -s compile/test_full_graph.py
 
 - label: Kernels Test %N # 1h each
   mirror_hardwares: [amd]
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1f4648a37dbca..7f6d1c66b2cf7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -83,24 +83,6 @@ endif()
 #
 find_package(Torch REQUIRED)
 
-#
-message(STATUS "Enabling core extension.")
-
-# Define _core_C extension
-#  built for (almost) every target platform, (excludes TPU and Neuron)
-
-set(VLLM_EXT_SRC
-  "csrc/core/torch_bindings.cpp")
-
-define_gpu_extension_target(
-  _core_C
-  DESTINATION vllm
-  LANGUAGE CXX
-  SOURCES ${VLLM_EXT_SRC}
-  COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
-  USE_SABI 3
-  WITH_SOABI)
-
 #
 # Forward the non-CUDA device extensions to external CMake scripts.
 #
diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp
index 0e1f360d74bd5..408e736d5bc0f 100644
--- a/csrc/core/scalar_type.hpp
+++ b/csrc/core/scalar_type.hpp
@@ -1,6 +1,7 @@
 #pragma once
 
-#include <torch/custom_class.h>
+// For TORCH_CHECK
+#include <torch/library.h>
 
 namespace vllm {
 
@@ -9,12 +10,7 @@ namespace vllm {
 // in particular it can be used to represent sub-byte data types (something
 // that torch.dtype currently does not support).
 //
-// ScalarTypeTorch is a subclass of ScalarType that is compatible with
-// TORCH_LIBRARY, making it accessible from Python as well meaning this class
-// can be used as a argument for custom operators, helping to simplify these
-// interfaces.
-//
-// The type definitions on the Python side can be found in: vllm/_core_ext.pyi
+// The type definitions on the Python side can be found in: vllm/scalar_type.py
 // these type definitions should be kept up to date with any Python API changes
 // here.
 //
@@ -308,204 +304,7 @@ class ScalarType {
   }
 };
 
-// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from
-// torch::CustomClassHolder), we use multiple inheritance here since we cannot
-// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
-// constructor at the same time (torch::CustomClassHolder does not have a
-// constexpr destructor)
-// See also:
-// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
-class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
- public:
-  ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
-                  bool _signed)
-      : ScalarType(exponent, mantissa, bias, _signed){};
-
-  ScalarTypeTorch(ScalarType type) : ScalarType(type){};
-
-  using Base = ScalarType;
-  using Self = ScalarTypeTorch;
-  using SelfPtr = c10::intrusive_ptr<Self>;
-
-  static void check_size_bits(int64_t size_bits, bool signed_) {
-    TORCH_CHECK(
-        size_bits <=
-            std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
-        "size_bits bit width is too large to be represented");
-  }
-
-  static void check_bias(int64_t bias) {
-    using Bias = decltype(std::declval<Self>().bias);
-    TORCH_CHECK(bias <= std::numeric_limits<Bias>::max() &&
-                    bias >= std::numeric_limits<Bias>::min(),
-                "bias too large or small to be represented");
-  }
-
-  static void check_exponent(int64_t exponent) {
-    TORCH_CHECK(
-        exponent <=
-            std::numeric_limits<decltype(std::declval<Self>().exponent)>::max(),
-        "exponent bit width is too large to be represented");
-  }
-
-  static void check_mantissa(int64_t mantissa) {
-    TORCH_CHECK(
-        mantissa <=
-            std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
-        "mantissa bit width is too large to be represented");
-  }
-
-  static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
-    check_size_bits(size_bits, true);
-    check_bias(bias.value_or(0));
-    return c10::make_intrusive<Self>(
-        ScalarType::int_(size_bits, bias.value_or(0)));
-  }
-
-  static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
-    check_size_bits(size_bits, true);
-    check_bias(bias.value_or(0));
-    return c10::make_intrusive<Self>(
-        ScalarType::uint(size_bits, bias.value_or(0)));
-  }
-
-  static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
-    check_mantissa(mantissa);
-    check_exponent(exponent);
-    return c10::make_intrusive<Self>(
-        ScalarType::float_IEEE754(exponent, mantissa));
-  }
-
-  static SelfPtr float_(int64_t exponent, int64_t mantissa,
-                        bool finite_values_only, int64_t nan_repr) {
-    check_mantissa(mantissa);
-    check_exponent(exponent);
-    return c10::make_intrusive<Self>(ScalarType::float_(
-        exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
-  }
-
-  // This needs to be implemented and throw a TypeError in order for
-  // PyTorch's opcheck to work on ops that use ScalarTypes.
-  int64_t len() const {
-    throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
-                         "__len__ not implemented");
-    return 0;
-  }
-
-  // Serialize a ScalarType into a tuple of pairs. Where each pair
-  // is a (fieldname, value).
-  // For simplicity, we are just going to convert to a ScalarTypeId.
-  std::tuple<std::tuple<std::string, int64_t>> obj_flatten() const {
-    return {{"ScalarType", id()}};
-  }
-
-  // Deserialize a scalar type that has been serialized by obj_flatten,
-  // ostensibly from a tuple of (member name, value) pairs, but in reality
-  // just a ScalarTypeId.
-  static SelfPtr obj_unflatten(
-      std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
-    return c10::make_intrusive<Self>(
-        from_id(std::get<1>(std::get<0>(flat_type))));
-  }
-
-  template <typename T>
-  static void bind_readonly_property(torch::class_<Self>& cls,
-                                     std::string const& name, T Base::*field) {
-    auto getter_func_helper = [field = std::move(field)](SelfPtr const& self) {
-      if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
-        return (self.get()->*field)();
-      } else {
-        return self.get()->*field;
-      }
-    };
-
-    auto getter_func = [field = std::move(field),
-                        getter_func_helper = std::move(getter_func_helper)](
-                           SelfPtr const& self) {
-      auto val = getter_func_helper(self);
-      // upconvert uint8_t, int32_t etc. to int64_t for python
-      if constexpr (std::is_integral_v<decltype(val)>) {
-        return static_cast<int64_t>(val);
-      } else {
-        return val;
-      }
-    };
-
-    cls.def_property(name, getter_func);
-  }
-
-  template <typename MemberFunc, typename Cls>
-  static void bind_function(torch::class_<Self>& cls, const std::string& name,
-                            MemberFunc Cls::*member) {
-    cls.def(name, [member = std::move(member)](SelfPtr const& self) {
-      return (self.get()->*member)();
-    });
-  }
-
-  template <typename Func>
-  static void bind_function(torch::class_<Self>& cls, const std::string& name,
-                            Func func) {
-    cls.def(name, func);
-  }
-
-  template <typename Func>
-  static void bind_static_function(torch::class_<Self>& cls,
-                                   const std::string& name, Func func) {
-    cls.def_static(name, func);
-  }
-
-  static void bind_class(torch::Library& lib) {
-    auto cls = lib.class_<ScalarTypeTorch>("ScalarType")
-                   .def(torch::init<int64_t, int64_t, int64_t, bool>());
-
-    // Bind Properties
-    bind_readonly_property(cls, "mantissa", &Base::mantissa);
-    bind_readonly_property(cls, "exponent", &Base::exponent);
-    bind_readonly_property(cls, "bias", &Base::bias);
-    bind_readonly_property(cls, "signed", &Base::is_signed);
-    bind_readonly_property(cls, "size_bits", &Base::size_bits);
-
-    // Bind member functions
-    bind_function(cls, "is_signed", &Base::is_signed);
-    bind_function(cls, "is_integer", &Base::is_integer);
-    bind_function(cls, "is_floating_point", &Base::is_floating_point);
-    bind_function(cls, "is_ieee_754", &Base::is_ieee_754);
-    bind_function(cls, "has_nans", &Base::has_nans);
-    bind_function(cls, "has_infs", &Base::has_infs);
-    bind_function(cls, "has_bias", &Base::has_bias);
-
-    bind_function(cls, "max", [](SelfPtr const& self) {
-      return std::visit([](auto arg) { return c10::IValue(arg); },
-                        self.get()->max());
-    });
-    bind_function(cls, "min", [](SelfPtr const& self) {
-      return std::visit([](auto arg) { return c10::IValue(arg); },
-                        self.get()->min());
-    });
-
-    bind_function(cls, "__len__", &ScalarTypeTorch::len);
-    bind_function(cls, "__str__", &Base::str);
-    bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
-      return *self == *other;
-    });
-    bind_function(cls, "__repr__", [](SelfPtr const& self) {
-      return "ScalarType." + self.get()->str();
-    });
-
-    bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten);
-    bind_static_function(cls, "__obj_unflatten__",
-                         &ScalarTypeTorch::obj_unflatten);
-
-    // Bind static functions (convenience constructors)
-    bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
-    bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
-    bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754);
-    bind_static_function(cls, "float_", &ScalarTypeTorch::float_);
-  }
-};
-
-using ScalarTypeId = int64_t;
-using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;
+using ScalarTypeId = ScalarType::Id;
 
 // "rust style" names generally following:
 // https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
diff --git a/csrc/core/torch_bindings.cpp b/csrc/core/torch_bindings.cpp
deleted file mode 100644
index f60254189a2f7..0000000000000
--- a/csrc/core/torch_bindings.cpp
+++ /dev/null
@@ -1,16 +0,0 @@
-#include <torch/library.h>
-
-#include "scalar_type.hpp"
-#include "registration.h"
-
-// Note the CORE exstension will be built for (almost) all hardware targets so
-// new additions must account for this. (currently not built for TPU and Neuron)
-
-TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) {
-  // ScalarType, a custom class for representing data types that supports
-  // quantized types, declared here so it can be used when creating interfaces
-  // for custom ops.
-  vllm::ScalarTypeTorch::bind_class(lib);
-}
-
-REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu
index e2db4e4196b6f..5f12483e951e8 100644
--- a/csrc/moe/marlin_moe_ops.cu
+++ b/csrc/moe/marlin_moe_ops.cu
@@ -484,21 +484,22 @@ torch::Tensor marlin_gemm_moe(
     const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
     torch::Tensor& b_zeros, const torch::Tensor& g_idx,
     const torch::Tensor& perm, torch::Tensor& workspace,
-    vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n,
+    vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
     int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
     int64_t moe_block_size, bool replicate_input, bool apply_weights) {
+  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
   bool has_zp = b_zeros.size(1) != 0;
   if (has_zp) {
     TORCH_CHECK(
-        *b_q_type == vllm::kU4,
-        "b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str());
+        b_q_type == vllm::kU4,
+        "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
   } else {
     TORCH_CHECK(
-        *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
-        "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str());
+        b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
+        "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str());
Got = ", b_q_type.str()); } - int pack_factor = 32 / b_q_type->size_bits(); + int pack_factor = 32 / b_q_type.size_bits(); int max_par = 4; @@ -575,7 +576,7 @@ torch::Tensor marlin_gemm_moe( topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, + b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, num_experts, topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 18fbc57ac7834..019c6cedd3d80 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -13,8 +13,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, " - "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " - "int size_n, int size_k, bool is_k_full, int num_experts, int topk, " + "int b_q_type, SymInt size_m, " + "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int " + "topk, " "int moe_block_size, bool replicate_input, bool apply_weights)" " -> Tensor"); // conditionally compiled so impl registration is in source file diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 5efe15d2b2f6b..6dbf9594e8492 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -80,7 +80,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& workspace, - vllm::ScalarTypeTorchPtr const& b_q_type, + vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool has_zp) { TORCH_CHECK_NOT_IMPLEMENTED(false, @@ -2132,22 +2132,23 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& workspace, - vllm::ScalarTypeTorchPtr const& b_q_type, + vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool has_zp, bool use_fp32_reduce) { + vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); if (has_zp) { - TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8, - "b_q_type must be u4 or u8 when has_zp = True. Got = ", - b_q_type->str()); + TORCH_CHECK( + b_q_type == vllm::kU4 || b_q_type == vllm::kU8, + "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); } else { TORCH_CHECK( - *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, "b_q_type must be uint4b8 or uint8b128 when has_zp = False. 
Got = ", - b_q_type->str()); + b_q_type.str()); } - int pack_factor = 32 / b_q_type->size_bits(); + int pack_factor = 32 / b_q_type.size_bits(); // Verify A TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), @@ -2279,7 +2280,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, c_tmp.data_ptr(), b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp, + workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); } else if (a.scalar_type() == at::ScalarType::BFloat16) { @@ -2288,7 +2289,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, c.data_ptr(), c_tmp.data_ptr(), b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp, + workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); } else { @@ -2302,4 +2303,4 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("gptq_marlin_gemm", &gptq_marlin_gemm); -} \ No newline at end of file +} diff --git a/csrc/quantization/machete/machete_pytorch.cu b/csrc/quantization/machete/machete_pytorch.cu index ff037756f55ab..9f9073ded6191 100644 --- a/csrc/quantization/machete/machete_pytorch.cu +++ b/csrc/quantization/machete/machete_pytorch.cu @@ -38,9 +38,10 @@ static auto scalar_type_dispatch(ScalarType const& type, Fn fn) { // Interface // -std::vector supported_schedules(ScalarTypeTorchPtr const& btype) { +std::vector supported_schedules(ScalarTypeId const btype_id) { #if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12 - return scalar_type_dispatch(*btype, [&](auto BType) { + vllm::ScalarType b_type = ScalarType::from_id(btype_id); + return scalar_type_dispatch(b_type, [&](auto BType) { return GemmDispatcher::supported_schedules(); }); #else @@ -49,7 +50,7 @@ std::vector supported_schedules(ScalarTypeTorchPtr const& btype) { } torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, - ScalarTypeTorchPtr const& btype, + ScalarTypeId const btype_id, c10::optional const& scales, c10::optional const& zeros, c10::optional group_size, @@ -57,6 +58,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, c10::optional alpha, c10::optional beta, c10::optional schedule) { #if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12 + ScalarType const btype = ScalarType::from_id(btype_id); auto args = PyTorchArguments{.A = A, .B = B, .scales = scales, @@ -67,7 +69,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, .beta = beta, .schedule = schedule}; - return scalar_type_dispatch(*btype, [&](auto BType) { + return scalar_type_dispatch(btype, [&](auto BType) { return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES( A.scalar_type(), "machete_gemm", [&] { using ComputeType = equivalent_cutlass_type_t; @@ -79,9 +81,9 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, #endif } -torch::Tensor prepack_B(torch::Tensor const& B, - vllm::ScalarTypeTorchPtr const& btype) { - return scalar_type_dispatch(*btype, [&](auto 
-    return PrepackBDispatcher::dispatch(B);
-  });
-}
+torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) {
+  ScalarType const btype = ScalarType::from_id(btype_id);
+  return scalar_type_dispatch(btype, [&](auto BType) {
+    return PrepackBDispatcher::dispatch(B);
+  });
+}
diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
index 908e4f70ab1e6..a33e2660d760e 100644
--- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
+++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
@@ -89,7 +89,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                   torch::Tensor& b_meta,
                                   torch::Tensor& b_scales,
                                   torch::Tensor& workspace,
-                                  vllm::ScalarTypeTorchPtr const& b_q_type,
+                                  vllm::ScalarTypeId const b_q_type_id,
                                   int64_t size_m, int64_t size_n,
                                   int64_t size_k) {
   TORCH_CHECK_NOT_IMPLEMENTED(
@@ -1029,13 +1029,14 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                   torch::Tensor& b_meta,
                                   torch::Tensor& b_scales,
                                   torch::Tensor& workspace,
-                                  vllm::ScalarTypeTorchPtr const& b_q_type,
+                                  vllm::ScalarTypeId const b_q_type_id,
                                   int64_t size_m, int64_t size_n,
                                   int64_t size_k) {
+  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
   // Verify num_bits
-  TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
-              "num_bits must be uint4b8 or uint8b128. Got = ", b_q_type->str());
-  int pack_factor = 32 / b_q_type->size_bits();
+  TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
+              "num_bits must be uint4b8 or uint8b128. Got = ", b_q_type.str());
+  int pack_factor = 32 / b_q_type.size_bits();
 
   // Verify M
   TORCH_CHECK(size_m == a.size(0),
@@ -1130,8 +1131,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   marlin_24::marlin_cuda_2_4(
       a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
       b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
-      b_q_type->size_bits(), groupsize, dev,
-      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par);
+      b_q_type.size_bits(), groupsize, dev, at::cuda::getCurrentCUDAStream(dev),
+      thread_k, thread_m, sms, max_par);
 
   return c;
 }
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index d69c4e5afb4a7..b999028fe06a9 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -140,13 +140,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Quantized GEMM for AWQ.
   ops.def(
       "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, int split_k_iters) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters) -> Tensor");
   ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
 
   // Dequantization for AWQ.
   ops.def(
       "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
   ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
 
   // Note about marlin kernel 'workspace' arguments:
@@ -166,32 +166,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
   ops.def(
       "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
-      "Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
+      "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
+      "Tensor");
   // conditionally compiled so impl in source file
 
   // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
   ops.def(
       "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
       "Tensor b_scales, Tensor workspace, "
-      "__torch__.torch.classes._core_C.ScalarType b_q_type, "
-      "int size_m, int size_n, int size_k) -> Tensor");
+      "int b_q_type, "
+      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
   // conditionally compiled so impl in source file
 
   // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
+  ops.def("machete_supported_schedules(int btype) -> str[]");
   ops.def(
-      "machete_supported_schedules("
-      "   __torch__.torch.classes._core_C.ScalarType btype"
-      ") -> str[]");
-  ops.def(
-      "machete_gemm(Tensor A, Tensor B,"
-      "   __torch__.torch.classes._core_C.ScalarType btype,"
-      "   Tensor? scales, Tensor? zeros, int? group_size,"
+      "machete_gemm(Tensor A, Tensor B, int btype, "
+      "   Tensor? scales, Tensor? zeros, int? group_size, "
       "   Tensor? C, float? alpha, float? beta, str? schedule)"
       "-> Tensor");
-  ops.def(
-      "machete_prepack_B(Tensor B,"
-      "   __torch__.torch.classes._core_C.ScalarType btype)"
-      "-> Tensor");
+  ops.def("machete_prepack_B(Tensor B, int btype) -> Tensor");
   // conditionally compiled so impl registration is in source file
 
   ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
@@ -201,8 +195,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
-      "__torch__.torch.classes._core_C.ScalarType b_q_type, "
-      "int size_m, int size_n, int size_k, bool is_k_full, "
+      "int b_q_type, "
+      "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
      "bool has_zp, bool use_fp32_reduce) -> Tensor");
   // conditionally compiled so impl registration is in source file
 
@@ -219,32 +213,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // conditionally compiled so impl registrations are in source file
 
   // Dequantization for GGML.
-  ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
+  ops.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor");
   ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
 
   // mmvq kernel for GGML.
   ops.def(
-      "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) "
+      "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
       "-> Tensor");
   ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
 
   // mmq kernel for GGML.
-  ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor");
+  ops.def(
+      "ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
   ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
 
   // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
   ops.def(
       "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
-      "Tensor! workspace, int num_bits, int size_m, int size_n, "
-      "int size_k) -> Tensor");
+      "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
+      "SymInt size_k) -> Tensor");
   // conditionally compiled so impl registration is in source file
 
   // marlin_qqq_gemm for QQQ.
   ops.def(
       "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
       "Tensor s_tok, Tensor s_ch, Tensor s_group, "
-      "Tensor! workspace, int size_m, int size_n, "
-      "int size_k) -> Tensor");
+      "Tensor! workspace, SymInt size_m, SymInt size_n, "
+      "SymInt size_k) -> Tensor");
   // conditionally compiled so impl registration is in source file
 
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
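
The `int` -> `SymInt` schema changes above are what allow torch.compile to trace these kernels with symbolic shapes: a `SymInt` argument stays symbolic through Dynamo instead of being burned in as a constant, as long as the registered fake (meta) implementation propagates it. A minimal sketch of the same pattern on a standalone op (the `demo::toy_gemm` namespace, name, and schema are illustrative, not part of this patch):

```python
import torch

# Hypothetical op whose size args are declared SymInt, mirroring the schemas
# above; torch.compile can then keep size_m/size_n symbolic while tracing.
lib = torch.library.Library("demo", "DEF")  # throwaway namespace for the sketch
lib.define("toy_gemm(Tensor a, Tensor b, SymInt size_m, SymInt size_n) -> Tensor")

@torch.library.register_fake("demo::toy_gemm")
def _toy_gemm_fake(a, b, size_m, size_n):
    # Shape/dtype propagation only: size_m/size_n may be SymInts during
    # tracing, and new_empty accepts them directly.
    return a.new_empty((size_m, size_n))
```

With the fake registered, the compiler can allocate the output symbolically; only the real CUDA/C++ implementation ever touches data.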
diff --git a/python_only_dev.py b/python_only_dev.py
index 72d4e78ee14f6..4ab203bb6f9d6 100644
--- a/python_only_dev.py
+++ b/python_only_dev.py
@@ -39,7 +39,6 @@ assert cwd != package_path, "should not import from the current directory"
 
 files_to_copy = [
     "vllm/_C.abi3.so",
-    "vllm/_core_C.abi3.so",
     "vllm/_moe_C.abi3.so",
     "vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so",
     "vllm/vllm_flash_attn/flash_attn_interface.py",
diff --git a/setup.py b/setup.py
index 9ea4e85c07542..d1f4b7f1c1119 100644
--- a/setup.py
+++ b/setup.py
@@ -290,10 +290,6 @@ def _build_custom_ops() -> bool:
     return _is_cuda() or _is_hip() or _is_cpu()
 
 
-def _build_core_ext() -> bool:
-    return not (_is_neuron() or _is_tpu() or _is_openvino() or _is_xpu())
-
-
 def get_hipcc_rocm_version():
     # Run the hipcc --version command
     result = subprocess.run(['hipcc', '--version'],
@@ -456,9 +452,6 @@ def get_requirements() -> List[str]:
 
 ext_modules = []
 
-if _build_core_ext():
-    ext_modules.append(CMakeExtension(name="vllm._core_C"))
-
 if _is_cuda() or _is_hip():
     ext_modules.append(CMakeExtension(name="vllm._moe_C"))
 
diff --git a/tests/compile/utils.py b/tests/compile/utils.py
index 5386eb0e3795d..c69343b51ae02 100644
--- a/tests/compile/utils.py
+++ b/tests/compile/utils.py
@@ -69,11 +69,11 @@ def check_full_graph_support(model,
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
     os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
 
-    # Inductor doesn't support fp8/gptq_marlin_24 yet.
+    # Inductor doesn't support fp8 and the base meta llama uses too
+    # much memory.
     quantization = model_kwargs.get("quantization")
-    if (quantization == "fp8" or quantization == "gptq_marlin"
-            or quantization == "gptq_marlin_24"
-        ) and optimization_level >= CompilationLevel.INDUCTOR:
+    if ((quantization == "fp8" or model == "meta-llama/Meta-Llama-3-8B")
+            and optimization_level >= CompilationLevel.INDUCTOR):
         return
 
     prompts = [
diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py
index 0fc2984a68ded..59c0a24753c3b 100644
--- a/tests/kernels/test_machete_gemm.py
+++ b/tests/kernels/test_machete_gemm.py
@@ -80,7 +80,7 @@ def machete_quantize_and_pack(w: torch.Tensor,
     w_q = w_q.t().contiguous().t()  # convert to col major
     w_q_machete = ops.machete_prepack_B(w_q, wtype)
 
-    opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype))
+    opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id))
 
     return w_ref, w_q_machete, w_s, w_zp
 
@@ -153,9 +153,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
             schedule=schedule,
         )
 
-        opcheck(torch.ops._C.machete_gemm,
-                (a, w_q_machete, wtype, w_s, maybe_convert_zeropoints(
-                    w_zp, w_s), group_size, None, None, None, schedule))
+        opcheck(
+            torch.ops._C.machete_gemm,
+            (a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints(
+                w_zp, w_s), group_size, None, None, None, schedule))
 
     # Relax atol as our reduction dim becomes larger (more rounding error)
     # Relax atol when we have zeropoints since the way machete applies
diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py
index a9bb72156c39e..5cfd4d6da7a86 100644
--- a/tests/kernels/test_marlin_gemm.py
+++ b/tests/kernels/test_marlin_gemm.py
@@ -225,7 +225,7 @@ def test_gptq_marlin_gemm(
     opcheck(
         torch.ops._C.gptq_marlin_gemm,
         (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
-         workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1],
+         workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
          a_input.shape[1], is_k_full, False, use_fp32_reduce),
         test_utils=DEFAULT_OPCHECK_TEST_UTILS)
 
@@ -254,6 +254,16 @@ def test_gptq_marlin_gemm(
     assert max_diff < 0.04
 
 
+# TODO: find better way to test this?
+@torch.compile(fullgraph=True)
+def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
+                          marlin_24_s, scratch, quant_type, size_m, size_n,
+                          size_k):
+    return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta,
+                                   marlin_24_s, scratch, quant_type, size_m,
+                                   size_n, size_k)
+
+
 @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                     reason="Marlin is not supported on this GPU type.")
 @pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@@ -282,11 +292,11 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
 
     opcheck(torch.ops._C.gptq_marlin_24_gemm,
             (a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
-             workspace_24.scratch, quant_type, a_input.shape[0],
+             workspace_24.scratch, quant_type.id, a_input.shape[0],
             b_weight.shape[1], a_input.shape[1]),
             test_utils=DEFAULT_OPCHECK_TEST_UTILS)
 
-    output = ops.gptq_marlin_24_gemm(
+    output = marlin_24_gemm_tester(
         a_input,
         marlin_24_q_w_comp,
         marlin_24_meta,
diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index b73c45b9cd198..b87fbc3f1937e 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -240,8 +240,8 @@ def test_fused_marlin_moe(
                             requires_grad=False)
     opcheck(torch.ops._moe_C.marlin_gemm_moe,
             (a, qweight1, sorted_token_ids, topk_weights, topk_ids,
-             scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m,
-             2 * n, k, True, e, topk, block_size_m, True, False))
+             scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id,
+             m, 2 * n, k, True, e, topk, block_size_m, True, False))
 
 
 @pytest.mark.skip("This test is here for the sake of debugging, "
diff --git a/tests/test_scalartype.py b/tests/test_scalartype.py
index 1201aaa92ea89..a9221f08c2946 100644
--- a/tests/test_scalartype.py
+++ b/tests/test_scalartype.py
@@ -32,5 +32,5 @@ def test_scalar_type_min_max(type_tuple):
       max = torch.iinfo(torch_type).max
 
     print(t, min, max, t.min(), t.max())
-    assert min == t.min()
-    assert max == t.max()
+    assert min == t.min(), f"min: {min} != {t.min()}"
+    assert max == t.max(), f"max: {max} != {t.max()}"
diff --git a/tools/report_build_time_ninja.py b/tools/report_build_time_ninja.py
index 3f9b68c2eccbe..33431a33ac837 100644
--- a/tools/report_build_time_ninja.py
+++ b/tools/report_build_time_ninja.py
@@ -16,7 +16,6 @@ Typical output looks like this:
        2.6 weighted s to build ...torch_bindings.cpp.o (31.5 s elapsed time)
        3.2 weighted s to build ...torch_bindings.cpp.o (38.5 s elapsed time)
     Longest build steps for .so (linking):
-       0.1 weighted s to build _core_C.abi3.so (0.7 s elapsed time)
        0.1 weighted s to build _moe_C.abi3.so (1.0 s elapsed time)
        0.5 weighted s to build ...flash_attn_c.abi3.so (1.1 s elapsed time)
        6.2 weighted s to build _C.abi3.so (6.2 s elapsed time)
diff --git a/vllm/_core_ext.py b/vllm/_core_ext.py
deleted file mode 100644
index a27b8648bee47..0000000000000
--- a/vllm/_core_ext.py
+++ /dev/null
@@ -1,278 +0,0 @@
-import importlib.util
-from enum import Enum
-from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
-
-import torch
-
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None
-
-
-# Mirrors enum in `core/scalar_type.hpp`
-class NanRepr(Enum):
-    NONE = 0  # nans are not supported
-    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
-    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
-
-
-if TYPE_CHECKING or not core_C_available:
-    # On platforms were we cannot use/build the C++ core extension (i.e. namely
-    # neuron and tpu), we define the mock ScalarType class here that partially
-    # mimics the C++ ScalarType class.
-    #
-    # We also use this provide type signatures to the Python LSP for the methods
-    # in the C++ ScalarType class. So these type signatures should be kept
-    # in sync with csrc/core/scalar_type.hpp
-
-    from dataclasses import dataclass
-
-    @dataclass(frozen=True)
-    class ScalarType:
-        """
-        ScalarType can represent a wide range of floating point and integer
-        types, in particular it can be used to represent sub-byte data types
-        (something that torch.dtype currently does not support). It is also
-        capable of representing types with a bias, i.e.:
-          `stored_value = value + bias`,
-        this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
-        of 8). The implementation for this class can be found in
-        csrc/core/scalar_type.hpp, these type signatures should be kept in sync
-        with that file.
-        """
-
-        exponent: int
-        """
-        Number of bits in the exponent if this is a floating point type
-        (zero if this an integer type)
-        """
-
-        mantissa: int
-        """
-        Number of bits in the mantissa if this is a floating point type,
-        or the number bits representing an integer excluding the sign bit if
-        this an integer type.
-        """
-
-        bias: int
-        """
-        bias used to encode the values in this scalar type
-        (value = stored_value - bias, default 0) for example if we store the
-        type as an unsigned integer with a bias of 128 then the value 0 will be
-        stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
-        """
-
-        signed: bool
-        "If the type is signed (i.e. has a sign bit)"
-
-        _finite_values_only: bool = False
-        """
-        Private: if NANs are supported, used `has_infs()` instead.
-        """
-
-        nan_repr: int = NanRepr.IEEE_754.value
-        """
-        How NaNs are represent in this scalar type, returns NanRepr value.
-        (not applicable for integer types)
-        """
-
-        @property
-        def size_bits(self):
-            return self.exponent + self.mantissa + int(self.signed)
-
-        def min(self) -> Union[int, float]:
-            """
-            Min representable value for this scalar type.
-            (accounting for bias if there is one)
-            """
-            raise NotImplementedError
-
-        def max(self) -> Union[int, float]:
-            """
-            Max representable value for this scalar type.
-            (accounting for bias if there is one)
-            """
-            raise NotImplementedError
-
-        def is_signed(self) -> bool:
-            """
-            If the type is signed (i.e. has a sign bit), same as `signed`
-            added for consistency with:
-            https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
-            """
-            ...
-
-        def is_floating_point(self) -> bool:
-            "If the type is a floating point type"
-            return self.exponent != 0
-
-        def is_integer(self) -> bool:
-            "If the type is an integer type"
-            return self.exponent == 0
-
-        def has_bias(self) -> bool:
-            "If the type has a non-zero bias"
-            return self.bias != 0
-
-        def has_infs(self) -> bool:
-            "If the type is floating point and supports infinity"
-            return not self._finite_values_only
-
-        def has_nans(self) -> bool:
-            return self.nan_repr != NanRepr.NONE.value
-
-        def is_ieee_754(self) -> bool:
-            """
-            If the type is a floating point type that follows IEEE 754
-            conventions
-            """
-            return self.nan_repr == NanRepr.IEEE_754.value and \
-                not self._finite_values_only
-
-        def __str__(self) -> str:
-            raise NotImplementedError
-
-        def __repr__(self) -> str:
-            raise NotImplementedError
-
-        # __len__ needs to be defined (and has to throw TypeError) for pytorch's
-        # opcheck to work.
-        def __len__(self) -> int:
-            raise TypeError
-
-        #
-        # Convenience Constructors
-        #
-
-        @classmethod
-        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
-            "Create a signed integer scalar type (size_bits includes sign-bit)."
-            return cls(size_bits - 1, size_bits, bias if bias else 0, True)
-
-        @classmethod
-        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
-            """Create a unsigned integer scalar type."""
-            return cls(size_bits, size_bits, bias if bias else 0, False)
-
-        @classmethod
-        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
-            """
-            Create a standard floating point type
-            (i.e. follows IEEE 754 conventions).
-            """
-            return cls(exponent, mantissa, 0, True)
-
-        @classmethod
-        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
-                   nan_repr: int) -> 'ScalarType':
-            """
-            Create a non-standard floating point type
-            (i.e. does not follow IEEE 754 conventions).
-            """
-            return cls(exponent, mantissa, 0, True, finite_values_only,
-                       nan_repr)
-
-elif core_C_available:
-    try:
-        import vllm._core_C  # noqa: F401
-    except ImportError as e:
-        logger.warning("Failed to import from vllm._core_C with %r", e)
-
-    ScalarType = torch.classes._core_C.ScalarType
-
-    if (hasattr(torch, "_library")
-            and hasattr(torch._library, "register_fake_class")):
-        # Needed for dynamo support of ScalarType.
-        @torch._library.register_fake_class("_core_C::ScalarType")
-        class FakeScalarType:
-
-            def __init__(self, scalar_type):
-                self.ScalarType = scalar_type
-
-            def bias_getter(self) -> int:
-                return self.ScalarType.bias
-
-            def exponent_getter(self) -> int:
-                return self.ScalarType.exponent
-
-            def mantissa_getter(self) -> int:
-                return self.ScalarType.mantissa
-
-            def signed_getter(self) -> bool:
-                return self.ScalarType.signed
-
-            def size_bits_getter(self) -> int:
-                return self.ScalarType.size_bits
-
-            @property
-            def size_bits(self) -> int:
-                return self.ScalarType.size_bits
-
-            def min(self) -> Union[int, float]:
-                return self.ScalarType.min()
-
-            def max(self) -> Union[int, float]:
-                return self.ScalarType.max()
-
-            def is_signed(self) -> bool:
-                return self.ScalarType.is_signed()
-
-            def is_floating_point(self) -> bool:
-                return self.ScalarType.is_floating_point()
-
-            def is_integer(self) -> bool:
-                return self.ScalarType.is_integer()
-
-            def has_bias(self) -> bool:
-                return self.ScalarType.has_bias()
-
-            def has_infs(self) -> bool:
-                return self.ScalarType.has_infs()
-
-            def has_nans(self) -> bool:
-                return self.ScalarType.has_nans()
-
-            def is_ieee_754(self) -> bool:
-                return self.ScalarType.is_ieee_754()
-
-            def __str__(self) -> str:
-                return self.ScalarType.__str__()
-
-            def __repr__(self) -> str:
-                return self.ScalarType.__repr__()
-
-            def __len__(self) -> int:
-                return self.ScalarType.__len__()
-
-            def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
-                return torch.classes._core_C.ScalarType.__obj_flatten__(
-                    self.ScalarType)
-
-            @classmethod
-            def __obj_unflatten__(
-                    cls, flat_type: Tuple[Tuple[str, Any],
-                                          ...]) -> 'ScalarType':
-                return cls(
-                    torch.classes._core_C.ScalarType.__obj_unflatten__(
-                        flat_type))
-
-            @classmethod
-            def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
-                return ScalarType.int_(size_bits, bias)
-
-            @classmethod
-            def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
-                return ScalarType.uint(size_bits, bias)
-
-            @classmethod
-            def float_IEEE754(cls, exponent: int,
-                              mantissa: int) -> 'ScalarType':
-                return ScalarType.float_IEEE754(exponent, mantissa)
-
-            @classmethod
-            def float_(cls, exponent: int, mantissa: int,
-                       finite_values_only: bool,
-                       nan_repr: int) -> 'ScalarType':
-                return ScalarType.float_(exponent, mantissa,
-                                         finite_values_only, nan_repr)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index ec035f137c3a6..b2952bbfa917c 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -6,9 +6,9 @@
 import torch
 import torch.library
 
 import vllm.envs as envs
-from vllm._core_ext import ScalarType
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.scalar_type import ScalarType
 
 logger = init_logger(__name__)
 
@@ -306,7 +306,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                         workspace: torch.Tensor, b_q_type: ScalarType,
                         size_m: int, size_n: int, size_k: int) -> torch.Tensor:
     return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
-                                            workspace, b_q_type, size_m,
+                                            workspace, b_q_type.id, size_m,
                                             size_n, size_k)
 
@@ -316,8 +316,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
     def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                   b_meta: torch.Tensor, b_scales: torch.Tensor,
                                   workspace: torch.Tensor,
-                                  b_q_type: ScalarType, size_m: int,
-                                  size_n: int, size_k: int) -> torch.Tensor:
+                                  b_q_type: ScalarType, size_m: torch.SymInt,
+                                  size_n: torch.SymInt,
+                                  size_k: torch.SymInt) -> torch.Tensor:
         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
 
     @register_fake("_C::gptq_marlin_gemm")
     def _gptq_marlin_gemm_fake(a: torch.Tensor,
@@ -329,17 +330,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
                                perm: torch.Tensor,
                                workspace: torch.Tensor,
                                b_q_type: ScalarType,
-                               size_m: int,
-                               size_n: int,
-                               size_k: int,
+                               size_m: torch.SymInt,
+                               size_n: torch.SymInt,
+                               size_k: torch.SymInt,
                                is_k_full: bool,
                                has_zp: bool = False,
                                use_fp32_reduce: bool = False) -> torch.Tensor:
         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
 
     @register_fake("_C::ggml_dequantize")
-    def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
-                              n: int) -> torch.Tensor:
+    def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
+                              m: torch.SymInt,
+                              n: torch.SymInt) -> torch.Tensor:
         return torch.empty((m, n), dtype=torch.float16, device=W.device)
 
     @register_fake("_C::ggml_mul_mat_vec_a8")
@@ -347,7 +349,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
         W: torch.Tensor,
         X: torch.Tensor,
         quant_type: int,
-        row: int,
+        row: torch.SymInt,
     ) -> torch.Tensor:
         return torch.empty((1, row), dtype=torch.float16, device=W.device)
 
    @register_fake("_C::ggml_mul_mat_a8")
@@ -356,7 +358,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
         W: torch.Tensor,
         X: torch.Tensor,
         quant_type: int,
-        row: int,
+        row: torch.SymInt,
     ) -> torch.Tensor:
         batch = X.size(0)
         return torch.empty((batch, row), dtype=torch.float16, device=W.device)
 
     @register_fake("_C::marlin_qqq_gemm")
@@ -365,8 +367,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
     def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                               s_tok: torch.Tensor, s_ch: torch.Tensor,
                               s_group: torch.Tensor, workspace: torch.Tensor,
-                              size_m: int, size_n: int,
-                              size_k: int) -> torch.Tensor:
+                              size_m: torch.SymInt, size_n: torch.SymInt,
+                              size_k: torch.SymInt) -> torch.Tensor:
         return torch.empty((size_m, size_n),
                            dtype=torch.float16,
                            device=a.device)
 
@@ -374,16 +376,16 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
     @register_fake("_C::marlin_gemm")
     def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                           b_scales: torch.Tensor, workspace: torch.Tensor,
-                          size_m: int, size_n: int,
-                          size_k: int) -> torch.Tensor:
+                          size_m: torch.SymInt, size_n: torch.SymInt,
+                          size_k: torch.SymInt) -> torch.Tensor:
         return torch.empty((size_m, size_n),
                            dtype=torch.float16,
                            device=a.device)
 
     @register_fake("_C::awq_dequantize")
     def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
-                             zeros: torch.Tensor, split_k_iters: int, thx: int,
-                             thy: int) -> torch.Tensor:
+                             zeros: torch.Tensor, split_k_iters: torch.SymInt,
+                             thx: int, thy: int) -> torch.Tensor:
         in_c = qweight.size(0)
         qout_c = qweight.size(1)
         out_c = qout_c * 8
@@ -394,7 +396,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
     @register_fake("_C::awq_gemm")
     def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                        qzeros: torch.Tensor, scales: torch.Tensor,
-                       split_k_iters: int) -> torch.Tensor:
+                       split_k_iters: torch.SymInt) -> torch.Tensor:
         num_in_feats = input.size(0)
         return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8),
                            dtype=input.dtype,
@@ -429,8 +431,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
     @register_fake("_C::fp8_marlin_gemm")
     def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor, workspace: torch.Tensor,
-                              num_bits: int, size_m: int, size_n: int,
-                              size_k: int) -> torch.Tensor:
+                              num_bits: int, size_m: torch.SymInt,
+                              size_n: torch.SymInt,
+                              size_k: torch.SymInt) -> torch.Tensor:
         return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
 
     @register_fake("_C::machete_gemm")
@@ -457,40 +460,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
         return torch.empty_like(b_q_weight,
                                 memory_format=torch.contiguous_format)
 
-    @register_fake("_C::causal_conv1d_fwd")
-    def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
-                               bias_: Optional[torch.Tensor],
-                               conv_states: Optional[torch.Tensor],
-                               cu_seq_len: Optional[torch.Tensor],
-                               cache_indices: Optional[torch.Tensor],
-                               has_initial_state: Optional[torch.Tensor],
-                               silu_activation: bool, pad_slot_id: int):
-        return None
-
-    @register_fake("_C::causal_conv1d_update")
-    def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
-                                  weight: torch.Tensor,
-                                  bias_: Optional[torch.Tensor],
-                                  silu_activation: bool,
-                                  cache_seqlens: Optional[torch.Tensor],
-                                  conv_state_indices: Optional[torch.Tensor],
-                                  pad_slot_id: int) -> None:
-        return None
-
-    @register_fake("_C::selective_scan_fwd")
-    def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
-                                A: torch.Tensor, B: torch.Tensor,
-                                C: torch.Tensor, D_: Optional[torch.Tensor],
-                                z_: Optional[torch.Tensor],
-                                delta_bias_: Optional[torch.Tensor],
-                                delta_softplus: bool,
-                                cu_seq_len: Optional[torch.Tensor],
-                                cache_indices: Optional[torch.Tensor],
-                                has_initial_state: Optional[torch.Tensor],
-                                ssm_states: Optional[torch.Tensor],
-                                pad_slot_id: int) -> None:
-        return None
-
 
 # cutlass
 def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
@@ -611,7 +580,7 @@ def gptq_marlin_gemm(a: torch.Tensor,
                      has_zp: bool = False,
                      use_fp32_reduce: bool = False) -> torch.Tensor:
     return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
-                                         g_idx, perm, workspace, b_q_type,
+                                         g_idx, perm, workspace, b_q_type.id,
                                          size_m, size_n, size_k, is_k_full,
                                          has_zp, use_fp32_reduce)
 
@@ -627,7 +596,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 # machete
 def machete_supported_schedules(b_type: ScalarType) -> List[str]:
-    return torch.ops._C.machete_supported_schedules(b_type)
+    return torch.ops._C.machete_supported_schedules(b_type.id)
 
 
 def machete_gemm(
@@ -642,13 +611,13 @@ def machete_gemm(
     beta: Optional[float] = None,
     schedule: Optional[str] = None,
 ) -> torch.Tensor:
-    return torch.ops._C.machete_gemm(a, b_q, b_type, b_scales, b_zeros,
+    return torch.ops._C.machete_gemm(a, b_q, b_type.id, b_scales, b_zeros,
                                      b_group_size, c, alpha, beta, schedule)
 
 
 def machete_prepack_B(b_q_weight: torch.Tensor,
                       b_type: ScalarType) -> torch.Tensor:
-    return torch.ops._C.machete_prepack_B(b_q_weight, b_type)
+    return torch.ops._C.machete_prepack_B(b_q_weight, b_type.id)
 
 
 if hasattr(torch.ops._C, "permute_cols"):
@@ -862,10 +831,10 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
                             topk_ids: torch.Tensor, b_scales: torch.Tensor,
                             b_zero_points: torch.Tensor, g_idx: torch.Tensor,
                             perm: torch.Tensor, workspace: torch.Tensor,
-                            b_q_type: ScalarType, size_m: int, size_n: int,
-                            size_k: int, is_k_full: bool, num_experts: int,
-                            topk: int, moe_block_size: int,
-                            replicate_input: bool,
+                            b_q_type: ScalarType, size_m: torch.SymInt,
+                            size_n: torch.SymInt, size_k: torch.SymInt,
+                            is_k_full: bool, num_experts: int, topk: int,
+                            moe_block_size: int, replicate_input: bool,
                             apply_weights: bool) -> torch.Tensor:
         return torch.empty((size_m, topk, size_n),
                            dtype=a.dtype,
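
The `_fake` stubs registered above implement the fake-tensor contract for the `_C` ops: produce outputs with the right shape, dtype, and device without ever launching a kernel. The mechanism can be seen with a stock ATen op (illustrative only, not from this patch):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Under FakeTensorMode, ops run only their shape/dtype logic -- exactly what
# the register_fake stubs above supply for the custom ops.
with FakeTensorMode():
    a = torch.empty(16, 32)
    b = torch.empty(32, 8)
    out = torch.matmul(a, b)
    assert out.shape == (16, 8) and out.device == a.device
```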
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 5964d5a5465fd..5ae40a2af5a2b 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -116,7 +116,7 @@ def single_marlin_moe(
 
     intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
         hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
-        w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K,
+        w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K,
         is_k_full, E, topk, block_size_m, True, False)
 
     return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
@@ -272,7 +272,7 @@ def fused_marlin_moe(
         g_idx1,
         sort_indices1,
         workspace,
-        scalar_type1,
+        scalar_type1.id,
         M,
         2 * N,
         K,
@@ -297,7 +297,7 @@ def fused_marlin_moe(
         g_idx2,
         sort_indices2,
         workspace,
-        scalar_type2,
+        scalar_type2.id,
         M,
         K,
         N,
diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py
index 373151a5311e5..9d711b0debcd8 100644
--- a/vllm/scalar_type.py
+++ b/vllm/scalar_type.py
@@ -1,4 +1,298 @@
-from ._core_ext import NanRepr, ScalarType
+import functools
+import struct
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional, Union
+
+
+# Mirrors enum in `core/scalar_type.hpp`
+class NanRepr(Enum):
+    NONE = 0  # nans are not supported
+    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
+    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
+
+
+# This ScalarType class is a parallel implementation of the C++ ScalarType
+# class found in csrc/core/scalar_type.hpp. These two classes should be kept
+# in sync until the inductor fully supports custom C++ classes.
+@dataclass(frozen=True)
+class ScalarType:
+    """
+    ScalarType can represent a wide range of floating point and integer
+    types, in particular it can be used to represent sub-byte data types
+    (something that torch.dtype currently does not support). It is also
+    capable of representing types with a bias, i.e.:
+      `stored_value = value + bias`,
+    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
+    of 8). The implementation for this class can be found in
+    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
+    with that file.
+    """
+
+    exponent: int
+    """
+    Number of bits in the exponent if this is a floating point type
+    (zero if this an integer type)
+    """
+
+    mantissa: int
+    """
+    Number of bits in the mantissa if this is a floating point type,
+    or the number bits representing an integer excluding the sign bit if
+    this an integer type.
+    """
+
+    signed: bool
+    "If the type is signed (i.e. has a sign bit)"
+
+    bias: int
+    """
+    bias used to encode the values in this scalar type
+    (value = stored_value - bias, default 0) for example if we store the
+    type as an unsigned integer with a bias of 128 then the value 0 will be
+    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
+    """
+
+    _finite_values_only: bool = False
+    """
+    Private: if infs are supported, use `has_infs()` instead.
+    """
+
+    nan_repr: NanRepr = NanRepr.IEEE_754
+    """
+    How NaNs are represented in this scalar type, returns NanRepr value.
+    (not applicable for integer types)
+    """
+
+    def _floating_point_max_int(self) -> int:
+        assert (
+            self.mantissa <= 52 and self.exponent <= 11
+        ), f"Cannot represent max/min as a double for type {self.__str__()}"
+
+        max_mantissa = (1 << self.mantissa) - 1
+        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
+            max_mantissa = max_mantissa - 1
+
+        max_exponent = (1 << self.exponent) - 2
+        if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
+                or self.nan_repr == NanRepr.NONE):
+            assert (
+                self.exponent < 11
+            ), f"Cannot represent max/min as a double for type {self.__str__()}"
+            max_exponent = max_exponent + 1
+
+        # adjust the exponent to match that of a double
+        # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
+        # e is the exponent bits), there is some precedent for non-standard
+        # biases, example `float8_e4m3b11fnuz` here:
+        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
+        # complication we are just assuming the standard exponent bias until
+        # there is a need to support non-standard biases
+        exponent_bias = (1 << (self.exponent - 1)) - 1
+        exponent_bias_double = (1 << 10) - 1  # double e = 11
+
+        max_exponent_double = (max_exponent - exponent_bias +
+                               exponent_bias_double)
+
+        # shift the mantissa and exponent into the proper positions for an
+        # IEEE double and bitwise-or them together.
+        return (max_mantissa <<
+                (52 - self.mantissa)) | (max_exponent_double << 52)
+
+    def _floating_point_max(self) -> float:
+        double_raw = self._floating_point_max_int()
+        return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
+
+    def _raw_max(self) -> Union[int, float]:
+        if self.is_floating_point():
+            return self._floating_point_max()
+        else:
+            assert (self.size_bits < 64 or self.size_bits == 64
+                    and self.is_signed()), "Cannot represent max as an int"
+            return (1 << self.mantissa) - 1
+
+    def _raw_min(self) -> Union[int, float]:
+        if self.is_floating_point():
+            assert self.is_signed(
+            ), "We currently assume all floating point types are signed"
+            sign_bit_double = 1 << 63
+
+            max_raw = self._floating_point_max_int()
+            min_raw = max_raw | sign_bit_double
+            return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
+        else:
+            assert (not self.is_signed() or
+                    self.size_bits <= 64), "Cannot represent min as a int64_t"
+
+            if self.is_signed():
+                return -(1 << (self.size_bits - 1))
+            else:
+                return 0
+
+    @functools.cached_property
+    def id(self) -> int:
+        """
+        Convert the ScalarType to an int which can be passed to pytorch custom
+        ops. This layout of the int must be kept in sync with the C++
+        ScalarType's from_id method.
+        """
+        val = 0
+        offset = 0
+
+        def or_and_advance(member, bit_width):
+            nonlocal val
+            nonlocal offset
+            bit_mask = (1 << bit_width) - 1
+            val = val | (int(member) & bit_mask) << offset
+            offset = offset + bit_width
+
+        or_and_advance(self.exponent, 8)
+        or_and_advance(self.mantissa, 8)
+        or_and_advance(self.signed, 1)
+        or_and_advance(self.bias, 32)
+        or_and_advance(self._finite_values_only, 1)
+        or_and_advance(self.nan_repr.value, 8)
+
+        assert offset <= 64, \
+            f"ScalarType fields too big {offset} to fit into an int64"
+
+        return val
+
+    @property
+    def size_bits(self) -> int:
+        return self.exponent + self.mantissa + int(self.signed)
+
+    def min(self) -> Union[int, float]:
+        """
+        Min representable value for this scalar type.
+        (accounting for bias if there is one)
+        """
+        return self._raw_min() - self.bias
+
+    def max(self) -> Union[int, float]:
+        """
+        Max representable value for this scalar type.
+        (accounting for bias if there is one)
+        """
+        return self._raw_max() - self.bias
+
+    def is_signed(self) -> bool:
+        """
+        If the type is signed (i.e. has a sign bit), same as `signed`
+        added for consistency with:
+        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
+        """
+        return self.signed
+
+    def is_floating_point(self) -> bool:
+        "If the type is a floating point type"
+        return self.exponent != 0
+
+    def is_integer(self) -> bool:
+        "If the type is an integer type"
+        return self.exponent == 0
+
+    def has_bias(self) -> bool:
+        "If the type has a non-zero bias"
+        return self.bias != 0
+
+    def has_infs(self) -> bool:
+        "If the type is floating point and supports infinity"
+        return not self._finite_values_only
+
+    def has_nans(self) -> bool:
+        return self.nan_repr != NanRepr.NONE
+
+    def is_ieee_754(self) -> bool:
+        """
+        If the type is a floating point type that follows IEEE 754
+        conventions
+        """
+        return self.nan_repr == NanRepr.IEEE_754 and \
+            not self._finite_values_only
+
+    def __str__(self) -> str:
+        """
+        naming generally follows: https://github.com/jax-ml/ml_dtypes
+        for floating point types (leading f) the scheme is:
+          `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
+        flags:
+          - no-flags: means it follows IEEE 754 conventions
+          - f: means finite values only (no infinities)
+          - n: means nans are supported (non-standard encoding)
+        for integer types the scheme is:
+          `[u]int<size_bits>[b<bias>]`
+          - if bias is not present it means its zero
+        """
+        if self.is_floating_point():
+            ret = "float" + str(self.size_bits) + "_e" + str(
+                self.exponent) + "m" + str(self.mantissa)
+
+            if not self.is_ieee_754():
+                if self._finite_values_only:
+                    ret = ret + "f"
+                if self.nan_repr != NanRepr.NONE:
+                    ret = ret + "n"
+
+            return ret
+        else:
+            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
+            if self.has_bias():
+                ret = ret + "b" + str(self.bias)
+            return ret
+
+    def __repr__(self) -> str:
+        return "ScalarType." + self.__str__()
+
+    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
+    # opcheck to work.
+    def __len__(self) -> int:
+        raise TypeError
+
+    #
+    # Convenience Constructors
+    #
+
+    @classmethod
+    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        "Create a signed integer scalar type (size_bits includes sign-bit)."
+        ret = cls(0, size_bits - 1, True, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        """Create an unsigned integer scalar type."""
+        ret = cls(0, size_bits, False, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
+        """
+        Create a standard floating point type
+        (i.e. follows IEEE 754 conventions).
+        """
+        assert (mantissa > 0 and exponent > 0)
+        ret = cls(exponent, mantissa, True, 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
+               nan_repr: NanRepr) -> 'ScalarType':
+        """
+        Create a non-standard floating point type
+        (i.e. does not follow IEEE 754 conventions).
+ """ + assert (mantissa > 0 and exponent > 0) + assert (nan_repr != NanRepr.IEEE_754), ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions") + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + # naming generally follows: https://github.com/jax-ml/ml_dtypes # for floating point types (leading f) the scheme is: @@ -17,14 +311,13 @@ class scalar_types: uint4 = ScalarType.uint(4, None) int8 = ScalarType.int_(8, None) uint8 = ScalarType.uint(8, None) - float8_e4m3fn = ScalarType.float_(4, 3, True, - NanRepr.EXTD_RANGE_MAX_MIN.value) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) float8_e5m2 = ScalarType.float_IEEE754(5, 2) float16_e8m7 = ScalarType.float_IEEE754(8, 7) float16_e5m10 = ScalarType.float_IEEE754(5, 10) # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main - float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value) + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) # "gptq" types uint2b2 = ScalarType.uint(2, 2)