mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 12:55:02 +08:00
[Kernel] Pass a device pointer into the quantize kernel for the scales (#5159)
This commit is contained in:
parent
0ab278ca31
commit
cbb2f59cc8
@ -94,8 +94,8 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
|
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||||
float scale);
|
torch::Tensor const& scale);
|
||||||
|
|
||||||
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||||
torch::Tensor lookup_table);
|
torch::Tensor lookup_table);
|
||||||
|
|||||||
@ -28,9 +28,10 @@ namespace vllm {
|
|||||||
template <typename scalar_t, typename scale_type>
|
template <typename scalar_t, typename scale_type>
|
||||||
__global__ void static_scaled_int8_quant_kernel(
|
__global__ void static_scaled_int8_quant_kernel(
|
||||||
const scalar_t* __restrict__ input, int8_t* __restrict__ out,
|
const scalar_t* __restrict__ input, int8_t* __restrict__ out,
|
||||||
scale_type scale, const int hidden_size) {
|
const scale_type* scale_ptr, const int hidden_size) {
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
const int token_idx = blockIdx.x;
|
const int token_idx = blockIdx.x;
|
||||||
|
scale_type scale = *scale_ptr;
|
||||||
|
|
||||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||||
out[token_idx * hidden_size + i] =
|
out[token_idx * hidden_size + i] =
|
||||||
@ -40,10 +41,12 @@ __global__ void static_scaled_int8_quant_kernel(
|
|||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||||
torch::Tensor& input, // [..., hidden_size]
|
torch::Tensor const& input, // [..., hidden_size]
|
||||||
float scale) {
|
torch::Tensor const& scale) {
|
||||||
TORCH_CHECK(input.is_contiguous());
|
TORCH_CHECK(input.is_contiguous());
|
||||||
TORCH_CHECK(out.is_contiguous());
|
TORCH_CHECK(out.is_contiguous());
|
||||||
|
TORCH_CHECK(scale.numel() == 1);
|
||||||
|
|
||||||
int hidden_size = input.size(-1);
|
int hidden_size = input.size(-1);
|
||||||
int num_tokens = input.numel() / hidden_size;
|
int num_tokens = input.numel() / hidden_size;
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
@ -53,7 +56,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
|||||||
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
|
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
|
||||||
vllm::static_scaled_int8_quant_kernel<scalar_t, float>
|
vllm::static_scaled_int8_quant_kernel<scalar_t, float>
|
||||||
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
|
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
|
||||||
out.data_ptr<int8_t>(), scale,
|
out.data_ptr<int8_t>(),
|
||||||
hidden_size);
|
scale.data_ptr<float>(), hidden_size);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@ -26,6 +26,8 @@ def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
|
|||||||
torch.iinfo(torch.int8).min,
|
torch.iinfo(torch.int8).min,
|
||||||
torch.iinfo(torch.int8).max).to(torch.int8)
|
torch.iinfo(torch.int8).max).to(torch.int8)
|
||||||
out2 = torch.empty_like(x, dtype=torch.int8)
|
out2 = torch.empty_like(x, dtype=torch.int8)
|
||||||
ops.static_scaled_int8_quant(out2, x, scale)
|
scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
|
||||||
|
|
||||||
|
ops.static_scaled_int8_quant(out2, x, scale_argument)
|
||||||
assert torch.allclose(out1, out2,
|
assert torch.allclose(out1, out2,
|
||||||
atol=1) # big atol to account for rounding errors
|
atol=1) # big atol to account for rounding errors
|
||||||
|
|||||||
@ -265,7 +265,7 @@ def scaled_fp8_quant(
|
|||||||
|
|
||||||
# int8
|
# int8
|
||||||
def static_scaled_int8_quant(input: torch.Tensor,
|
def static_scaled_int8_quant(input: torch.Tensor,
|
||||||
scale: float) -> torch.Tensor:
|
scale: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Quantize the input tensor to int8 and return the quantized tensor.
|
Quantize the input tensor to int8 and return the quantized tensor.
|
||||||
|
|
||||||
|
|||||||
@ -97,7 +97,7 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
|
|||||||
act_scale = layer.input_scale
|
act_scale = layer.input_scale
|
||||||
|
|
||||||
# Input quantize
|
# Input quantize
|
||||||
x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item())
|
x_q = custom_ops.static_scaled_int8_quant(x, act_scale)
|
||||||
|
|
||||||
return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
|
return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
|
||||||
weight_scale, x.dtype)
|
weight_scale, x.dtype)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user