Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 04:15:01 +08:00)
[Kernel][Misc] dynamo support for ScalarType (#7594)
commit 7759ae958f (parent 9f69856356)
csrc/core/scalar_type.hpp
@@ -313,6 +313,8 @@ class ScalarType {
 // have ScalarType inherit from torch::CustomClassHolder and have a constexpr
 // constructor at the same time (torch::CustomClassHolder does not have a
 // constexpr destructor)
+// See also:
+// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
 class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
  public:
   ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
@@ -382,6 +384,29 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
         exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
   }
 
+  // This needs to be implemented and throw a TypeError in order for
+  // PyTorch's opcheck to work on ops that use ScalarTypes.
+  int64_t len() const {
+    throw c10::TypeError("__len__ not implemented");
+    return 0;
+  }
+
+  // Serialize a ScalarType into a tuple of pairs. Where each pair
+  // is a (fieldname, value).
+  // For simplicity, we are just going to convert to a ScalarTypeId.
+  std::tuple<std::tuple<std::string, int64_t>> obj_flatten() const {
+    return {{"ScalarType", id()}};
+  }
+
+  // Deserialize a scalar type that has been serialized by obj_flatten,
+  // ostensibly from a tuple of (member name, value) pairs, but in reality
+  // just a ScalarTypeId.
+  static SelfPtr obj_unflatten(
+      std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
+    return c10::make_intrusive<Self>(
+        from_id(std::get<1>(std::get<0>(flat_type))));
+  }
+
   template <typename T>
   static void bind_readonly_property(torch::class_<Self>& cls,
                                      std::string const& name, T Base::*field) {
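Note: obj_flatten/obj_unflatten implement the (fieldname, value) serialization protocol PyTorch uses to move custom-class instances across the dynamo boundary. A minimal sketch of the round trip as seen from Python, once these methods are bound below (assumes vllm's _core_C extension is loaded):

# Sketch: round-tripping a ScalarType through the flatten protocol.
import torch

t = torch.classes._core_C.ScalarType.uint(4, 8)   # GPTQ-style uint4, bias 8
flat = t.__obj_flatten__()                         # (("ScalarType", <id>),)
t2 = torch.classes._core_C.ScalarType.__obj_unflatten__(flat)
assert str(t) == str(t2)                           # same type after round trip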
@@ -457,6 +482,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
                       self.get()->min());
         });
 
+    bind_function(cls, "__len__", &ScalarTypeTorch::len);
     bind_function(cls, "__str__", &Base::str);
     bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
       return *self == *other;
@@ -465,6 +491,10 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
       return "ScalarType." + self.get()->str();
     });
 
+    bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten);
+    bind_static_function(cls, "__obj_unflatten__",
+                         &ScalarTypeTorch::obj_unflatten);
+
     // Bind static functions (convenience constructors)
     bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
     bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
vllm/scalar_type.py
@@ -1,6 +1,6 @@
 import importlib.util
 from enum import Enum
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
 
 import torch
 
@@ -31,14 +31,14 @@ if TYPE_CHECKING or not core_C_available:
     @dataclass(frozen=True)
     class ScalarType:
         """
         ScalarType can represent a wide range of floating point and integer
         types, in particular it can be used to represent sub-byte data types
         (something that torch.dtype currently does not support). It is also
         capable of representing types with a bias, i.e.:
           `stored_value = value + bias`,
         this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
         of 8). The implementation for this class can be found in
         csrc/core/scalar_type.hpp, these type signatures should be kept in sync
         with that file.
         """
 
@@ -51,15 +51,15 @@ if TYPE_CHECKING or not core_C_available:
         mantissa: int
         """
         Number of bits in the mantissa if this is a floating point type,
         or the number bits representing an integer excluding the sign bit if
         this an integer type.
         """
 
         bias: int
         """
         bias used to encode the values in this scalar type
         (value = stored_value - bias, default 0) for example if we store the
         type as an unsigned integer with a bias of 128 then the value 0 will be
         stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
         """
 
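The bias convention documented in this hunk is easy to check by hand; a small sketch (pure Python, no vllm imports) of the stored/logical mapping for the GPTQ-style uint4 type with bias 8:

# Sketch: stored_value = value + bias for a uint4 type with bias = 8.
size_bits, bias = 4, 8
stored_min, stored_max = 0, 2**size_bits - 1   # raw storage range: 0..15
logical_min = stored_min - bias                # -8
logical_max = stored_max - bias                # 7
assert (logical_min, logical_max) == (-8, 7)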
@@ -73,7 +73,7 @@ if TYPE_CHECKING or not core_C_available:
 
         nan_repr: int = NanRepr.IEEE_754.value
         """
         How NaNs are represent in this scalar type, returns NanRepr value.
         (not applicable for integer types)
         """
 
@@ -83,14 +83,14 @@ if TYPE_CHECKING or not core_C_available:
 
         def min(self) -> Union[int, float]:
             """
             Min representable value for this scalar type.
             (accounting for bias if there is one)
             """
             raise NotImplementedError
 
         def max(self) -> Union[int, float]:
             """
             Max representable value for this scalar type.
             (accounting for bias if there is one)
             """
             raise NotImplementedError
@@ -103,28 +103,28 @@ if TYPE_CHECKING or not core_C_available:
             """
             ...
 
-        def is_floating_point(self):
+        def is_floating_point(self) -> bool:
             "If the type is a floating point type"
             return self.exponent != 0
 
-        def is_integer(self):
+        def is_integer(self) -> bool:
             "If the type is an integer type"
             return self.exponent == 0
 
-        def has_bias(self):
+        def has_bias(self) -> bool:
             "If the type has a non-zero bias"
             return self.bias != 0
 
-        def has_infs(self):
+        def has_infs(self) -> bool:
             "If the type is floating point and supports infinity"
             return not self._finite_values_only
 
-        def has_nans(self):
+        def has_nans(self) -> bool:
             return self.nan_repr != NanRepr.NONE.value
 
         def is_ieee_754(self) -> bool:
             """
             If the type is a floating point type that follows IEEE 754
             conventions
             """
             return self.nan_repr == NanRepr.IEEE_754.value and \
@@ -136,6 +136,11 @@ if TYPE_CHECKING or not core_C_available:
         def __repr__(self) -> str:
             raise NotImplementedError
 
+        # __len__ needs to be defined (and has to throw TypeError) for pytorch's
+        # opcheck to work.
+        def __len__(self) -> int:
+            raise TypeError
+
         #
         # Convenience Constructors
         #
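The TypeError here is deliberate: torch.library.opcheck probes the arguments of custom ops, and an object whose __len__ succeeds can be mistaken for a sequence, so both this Python stub and the C++ binding above raise instead of returning a length. A small sketch of the behaviour being relied on:

# Sketch: ScalarType defines __len__ only so that it raises.
st = ScalarType.float_IEEE754(5, 10)   # fp16-like type
try:
    len(st)                            # __len__ raises TypeError on purpose
except TypeError:
    pass                               # expected: ScalarType is not a sequence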
@@ -153,16 +158,16 @@ if TYPE_CHECKING or not core_C_available:
         @classmethod
         def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
             """
             Create a standard floating point type
             (i.e. follows IEEE 754 conventions).
             """
             return cls(exponent, mantissa, 0, True)
 
         @classmethod
         def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
-                   nan_repr: int):
+                   nan_repr: int) -> 'ScalarType':
             """
             Create a non-standard floating point type
             (i.e. does not follow IEEE 754 conventions).
             """
             return cls(exponent, mantissa, 0, True, finite_values_only,
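As a quick illustration of the two floating-point constructors (a sketch; the NanRepr member name used for the non-IEEE case is an assumption beyond what this hunk shows):

# fp16 has 5 exponent bits and 10 mantissa bits and follows IEEE 754:
fp16 = ScalarType.float_IEEE754(5, 10)
assert fp16.is_ieee_754() and fp16.is_floating_point()

# A non-standard fp8-e4m3-style type: finite values only, custom NaN encoding.
# NanRepr.EXTD_RANGE_MAX_MIN is assumed from the NanRepr enum referenced above.
fp8 = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN.value)
assert not fp8.is_ieee_754()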
@@ -175,3 +180,93 @@ elif core_C_available:
         logger.warning("Failed to import from vllm._core_C with %r", e)
 
     ScalarType = torch.classes._core_C.ScalarType
+
+    # Needed for dynamo support of ScalarType.
+    @torch._library.register_fake_class("_core_C::ScalarType")
+    class FakeScalarType:
+
+        def __init__(self, scalar_type):
+            self.ScalarType = scalar_type
+
+        def bias_getter(self) -> int:
+            return self.ScalarType.bias
+
+        def exponent_getter(self) -> int:
+            return self.ScalarType.exponent
+
+        def mantissa_getter(self) -> int:
+            return self.ScalarType.mantissa
+
+        def signed_getter(self) -> bool:
+            return self.ScalarType.signed
+
+        def size_bits_getter(self) -> int:
+            return self.ScalarType.size_bits
+
+        @property
+        def size_bits(self) -> int:
+            return self.ScalarType.size_bits
+
+        def min(self) -> Union[int, float]:
+            return self.ScalarType.min()
+
+        def max(self) -> Union[int, float]:
+            return self.ScalarType.max()
+
+        def is_signed(self) -> bool:
+            return self.ScalarType.is_signed()
+
+        def is_floating_point(self) -> bool:
+            return self.ScalarType.is_floating_point()
+
+        def is_integer(self) -> bool:
+            return self.ScalarType.is_integer()
+
+        def has_bias(self) -> bool:
+            return self.ScalarType.has_bias()
+
+        def has_infs(self) -> bool:
+            return self.ScalarType.has_infs()
+
+        def has_nans(self) -> bool:
+            return self.ScalarType.has_nans()
+
+        def is_ieee_754(self) -> bool:
+            return self.ScalarType.is_ieee_754()
+
+        def __str__(self) -> str:
+            return self.ScalarType.__str__()
+
+        def __repr__(self) -> str:
+            return self.ScalarType.__repr__()
+
+        def __len__(self) -> int:
+            return self.ScalarType.__len__()
+
+        def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
+            return torch.classes._core_C.ScalarType.__obj_flatten__(
+                self.ScalarType)
+
+        @classmethod
+        def __obj_unflatten__(
+                cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType':
+            return cls(
+                torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type))
+
+        @classmethod
+        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+            return ScalarType.int_(size_bits, bias)
+
+        @classmethod
+        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+            return ScalarType.uint(size_bits, bias)
+
+        @classmethod
+        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
+            return ScalarType.float_IEEE754(exponent, mantissa)
+
+        @classmethod
+        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
+                   nan_repr: int) -> 'ScalarType':
+            return ScalarType.float_(exponent, mantissa, finite_values_only,
+                                     nan_repr)
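Taken together, this is the standard recipe for making a TorchBind class traceable: register a Python "fake" mirror so dynamo reconstructs ScalarType values from their __obj_flatten__ form instead of tracing into C++. A hedged usage sketch of what the commit enables (not code from this commit):

# Sketch: with FakeScalarType registered, dynamo can trace functions that
# close over a ScalarType instance.
import torch
from vllm.scalar_type import ScalarType

wtype = ScalarType.uint(4, 8)          # GPTQ-style uint4 with bias 8

@torch.compile
def f(x: torch.Tensor) -> torch.Tensor:
    # Under tracing, dynamo sees FakeScalarType, rebuilt via
    # __obj_unflatten__ from the (("ScalarType", id),) flattened form.
    if wtype.has_bias():               # attribute/method access is traceable
        return x * wtype.max()
    return x

out = f(torch.randn(4))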