diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp
index e0f4d9f3c5cf..b1e10fecb6b5 100644
--- a/csrc/core/scalar_type.hpp
+++ b/csrc/core/scalar_type.hpp
@@ -313,6 +313,8 @@ class ScalarType {
 // have ScalarType inherit from torch::CustomClassHolder and have a constexpr
 // constructor at the same time (torch::CustomClassHolder does not have a
 // constexpr destructor)
+// See also:
+// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
 class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
  public:
   ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
@@ -382,6 +384,29 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
         exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
   }
 
+  // This needs to be implemented and throw a TypeError in order for
+  // PyTorch's opcheck to work on ops that use ScalarTypes.
+  int64_t len() const {
+    throw c10::TypeError("__len__ not implemented");
+    return 0;
+  }
+
+  // Serialize a ScalarType into a tuple of pairs. Where each pair
+  // is a (fieldname, value).
+  // For simplicity, we are just going to convert to a ScalarTypeId.
+  std::tuple<std::tuple<std::string, int64_t>> obj_flatten() const {
+    return {{"ScalarType", id()}};
+  }
+
+  // Deserialize a scalar type that has been serialized by obj_flatten,
+  // ostensibly from a tuple of (member name, value) pairs, but in reality
+  // just a ScalarTypeId.
+  static SelfPtr obj_unflatten(
+      std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
+    return c10::make_intrusive<Self>(
+        from_id(std::get<1>(std::get<0>(flat_type))));
+  }
+
   template <typename T>
   static void bind_readonly_property(torch::class_<Self>& cls,
                                      std::string const& name, T Base::*field) {
@@ -457,6 +482,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
           self.get()->min());
     });
 
+    bind_function(cls, "__len__", &ScalarTypeTorch::len);
     bind_function(cls, "__str__", &Base::str);
     bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
       return *self == *other;
@@ -465,6 +491,10 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
       return "ScalarType." + self.get()->str();
     });
 
+    bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten);
+    bind_static_function(cls, "__obj_unflatten__",
+                         &ScalarTypeTorch::obj_unflatten);
+
     // Bind static functions (convenience constructors)
     bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
     bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
diff --git a/vllm/_core_ext.py b/vllm/_core_ext.py
index e3b9fbb93891..aa520e1eafba 100644
--- a/vllm/_core_ext.py
+++ b/vllm/_core_ext.py
@@ -1,6 +1,6 @@
 import importlib.util
 from enum import Enum
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
 
 import torch
 
@@ -31,14 +31,14 @@ if TYPE_CHECKING or not core_C_available:
     @dataclass(frozen=True)
     class ScalarType:
         """
-        ScalarType can represent a wide range of floating point and integer 
-        types, in particular it can be used to represent sub-byte data types 
-        (something that torch.dtype currently does not support). It is also 
+        ScalarType can represent a wide range of floating point and integer
+        types, in particular it can be used to represent sub-byte data types
+        (something that torch.dtype currently does not support). It is also
         capable of representing types with a bias, i.e.:
-          `stored_value = value + bias`, 
-        this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias 
-        of 8). The implementation for this class can be found in 
-        csrc/core/scalar_type.hpp, these type signatures should be kept in sync 
+          `stored_value = value + bias`,
+        this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
+        of 8). The implementation for this class can be found in
+        csrc/core/scalar_type.hpp, these type signatures should be kept in sync
         with that file.
         """
 
@@ -51,15 +51,15 @@ if TYPE_CHECKING or not core_C_available:
         mantissa: int
         """
         Number of bits in the mantissa if this is a floating point type,
-        or the number bits representing an integer excluding the sign bit if 
+        or the number bits representing an integer excluding the sign bit if
         this an integer type.
         """
 
         bias: int
         """
-        bias used to encode the values in this scalar type 
-        (value = stored_value - bias, default 0) for example if we store the 
-        type as an unsigned integer with a bias of 128 then the value 0 will be 
+        bias used to encode the values in this scalar type
+        (value = stored_value - bias, default 0) for example if we store the
+        type as an unsigned integer with a bias of 128 then the value 0 will be
         stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
         """
 
@@ -73,7 +73,7 @@ if TYPE_CHECKING or not core_C_available:
 
         nan_repr: int = NanRepr.IEEE_754.value
         """
-        How NaNs are represent in this scalar type, returns NanRepr value. 
+        How NaNs are represent in this scalar type, returns NanRepr value.
         (not applicable for integer types)
         """
 
@@ -83,14 +83,14 @@ if TYPE_CHECKING or not core_C_available:
 
         def min(self) -> Union[int, float]:
             """
-            Min representable value for this scalar type. 
+            Min representable value for this scalar type.
             (accounting for bias if there is one)
             """
             raise NotImplementedError
 
         def max(self) -> Union[int, float]:
             """
-            Max representable value for this scalar type. 
+            Max representable value for this scalar type.
             (accounting for bias if there is one)
             """
             raise NotImplementedError
@@ -103,28 +103,28 @@ if TYPE_CHECKING or not core_C_available:
             """
             ...
 
-        def is_floating_point(self):
+        def is_floating_point(self) -> bool:
             "If the type is a floating point type"
             return self.exponent != 0
 
-        def is_integer(self):
+        def is_integer(self) -> bool:
             "If the type is an integer type"
             return self.exponent == 0
 
-        def has_bias(self):
+        def has_bias(self) -> bool:
             "If the type has a non-zero bias"
             return self.bias != 0
 
-        def has_infs(self):
+        def has_infs(self) -> bool:
             "If the type is floating point and supports infinity"
             return not self._finite_values_only
 
-        def has_nans(self):
+        def has_nans(self) -> bool:
             return self.nan_repr != NanRepr.NONE.value
 
         def is_ieee_754(self) -> bool:
             """
-            If the type is a floating point type that follows IEEE 754 
+            If the type is a floating point type that follows IEEE 754
             conventions
             """
             return self.nan_repr == NanRepr.IEEE_754.value and \
@@ -136,6 +136,11 @@ if TYPE_CHECKING or not core_C_available:
         def __repr__(self) -> str:
             raise NotImplementedError
 
+        # __len__ needs to be defined (and has to throw TypeError) for pytorch's
+        # opcheck to work.
+        def __len__(self) -> int:
+            raise TypeError
+
         #
         # Convenience Constructors
         #
@@ -153,16 +158,16 @@ if TYPE_CHECKING or not core_C_available:
         @classmethod
         def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
             """
-            Create a standard floating point type 
+            Create a standard floating point type
             (i.e. follows IEEE 754 conventions).
             """
             return cls(exponent, mantissa, 0, True)
 
         @classmethod
         def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
-                   nan_repr: int):
+                   nan_repr: int) -> 'ScalarType':
             """
-            Create a non-standard floating point type 
+            Create a non-standard floating point type
             (i.e. does not follow IEEE 754 conventions).
             """
             return cls(exponent, mantissa, 0, True, finite_values_only,
@@ -175,3 +180,93 @@ elif core_C_available:
         logger.warning("Failed to import from vllm._core_C with %r", e)
 
     ScalarType = torch.classes._core_C.ScalarType
+
+    # Needed for dynamo support of ScalarType.
+    @torch._library.register_fake_class("_core_C::ScalarType")
+    class FakeScalarType:
+
+        def __init__(self, scalar_type):
+            self.ScalarType = scalar_type
+
+        def bias_getter(self) -> int:
+            return self.ScalarType.bias
+
+        def exponent_getter(self) -> int:
+            return self.ScalarType.exponent
+
+        def mantissa_getter(self) -> int:
+            return self.ScalarType.mantissa
+
+        def signed_getter(self) -> bool:
+            return self.ScalarType.signed
+
+        def size_bits_getter(self) -> int:
+            return self.ScalarType.size_bits
+
+        @property
+        def size_bits(self) -> int:
+            return self.ScalarType.size_bits
+
+        def min(self) -> Union[int, float]:
+            return self.ScalarType.min()
+
+        def max(self) -> Union[int, float]:
+            return self.ScalarType.max()
+
+        def is_signed(self) -> bool:
+            return self.ScalarType.is_signed()
+
+        def is_floating_point(self) -> bool:
+            return self.ScalarType.is_floating_point()
+
+        def is_integer(self) -> bool:
+            return self.ScalarType.is_integer()
+
+        def has_bias(self) -> bool:
+            return self.ScalarType.has_bias()
+
+        def has_infs(self) -> bool:
+            return self.ScalarType.has_infs()
+
+        def has_nans(self) -> bool:
+            return self.ScalarType.has_nans()
+
+        def is_ieee_754(self) -> bool:
+            return self.ScalarType.is_ieee_754()
+
+        def __str__(self) -> str:
+            return self.ScalarType.__str__()
+
+        def __repr__(self) -> str:
+            return self.ScalarType.__repr__()
+
+        def __len__(self) -> int:
+            return self.ScalarType.__len__()
+
+        def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
+            return torch.classes._core_C.ScalarType.__obj_flatten__(
+                self.ScalarType)
+
+        @classmethod
+        def __obj_unflatten__(
+                cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType':
+            return cls(
+                torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type))
+
+        @classmethod
+        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+            return ScalarType.int_(size_bits, bias)
+
+        @classmethod
+        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+            return ScalarType.uint(size_bits, bias)
+
+        @classmethod
+        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
+            return ScalarType.float_IEEE754(exponent, mantissa)
+
+        @classmethod
+        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
+                   nan_repr: int) -> 'ScalarType':
+            return ScalarType.float_(exponent, mantissa, finite_values_only,
+                                     nan_repr)