[BugFix][Kernel] Fix Illegal memory access in causal_conv1d in H100 (#9838)
Signed-off-by: mzusman <mor.zusmann@gmail.com>
Parent: 55650c83a0
Commit: 9fb12f7848
@@ -418,6 +418,31 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
             typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
         }
         out += kChunkSize;
+
+        int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
+        // in case the final state is separated between the last "smem_exchange"
+        // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
+        // (which occurs when `final_state_position` is a non-positive index),
+        // we load the correct data from smem_exchange for both chunks, the last chunk iteration and the one before it
+        if (final_state_position < 0 && seqlen > kWidth){
+            input_t vals_load[kNElts] = {0};
+            if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
+                // chunk = n_chunks - 2, a segment of the final state sits in the last index
+                reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
+                #pragma unroll
+                for (int w = 0; w < -final_state_position; ++w){
+                    conv_states[w] = vals_load[kNElts + final_state_position + w];
+                }
+            }
+            if ((chunk == n_chunks - 1) && tidx == 0){
+                // chunk = n_chunks - 1, the second segment of the final state sits in the first positions
+                reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
+                for (int w = -final_state_position; w < kWidth - 1; ++w){
+                    conv_states[w] = vals_load[w + final_state_position];
+                }
+                return;
+            }
+        }
     }
     // Final state is stored in the smem_exchange last token slot,
     // in case seqlen < kWidth, we would need to take the final state from the
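The added block handles the case where the final convolution state (the last kWidth - 1 input tokens) is split between the last two chunks. Below is a minimal Python sketch of that index arithmetic, not the kernel itself: the concrete values kNThreads = 128, kNElts = 8 and kWidth = 4 (hence kChunkSize = 1024) are assumptions chosen to mirror a typical template instantiation, and the sketch only reproduces which input positions end up in conv_states, whereas the kernel reads them out of smem_exchange vectors.

```python
# Sketch of the chunk-boundary index arithmetic above (assumed parameters).
kNThreads, kNElts, kWidth = 128, 8, 4      # assumption: one common instantiation
kChunkSize = kNThreads * kNElts            # 1024 elements per chunk


def final_state_token_indices(seqlen: int) -> list[int]:
    """Which input positions form the final conv state (last kWidth - 1 tokens),
    following the two branches of the kernel's new chunk-boundary path."""
    n_chunks = (seqlen + kChunkSize - 1) // kChunkSize
    final_state_position = (seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize
    if final_state_position >= 0 or seqlen <= kWidth:
        # Easy case: the whole final state lives inside the last chunk.
        return list(range(seqlen - (kWidth - 1), seqlen))
    base = (n_chunks - 1) * kChunkSize     # first token of the last chunk
    idx = []
    # chunk == n_chunks - 2: tail of the previous chunk's last smem_exchange vector
    for w in range(-final_state_position):
        idx.append(base + final_state_position + w)
    # chunk == n_chunks - 1: head of the last chunk's first smem_exchange vector
    for w in range(-final_state_position, kWidth - 1):
        idx.append(base + final_state_position + w)
    return idx


# seqlen = 1025 straddles the boundary: final_state_position = (1025 - 3) - 1024 = -2,
# so two state elements come from chunk n_chunks - 2 and one from chunk n_chunks - 1.
assert final_state_token_indices(1025) == [1022, 1023, 1024]
```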
@@ -446,9 +471,14 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
         }
         else {
             // in case the final state is in between the threads data
-            reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
-            reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
             const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
+            if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){
+                // In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in an
+                // illegal access error on H100.
+                // Therefore, we access last_thread + 1 only if the final state data sits there
+                reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
+            }
+            reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
             #pragma unroll
             for (int w = 0; w < kWidth - 1; ++w){
                 conv_states[w] = x_vals_load[offset + w ];
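This hunk is the actual H100 fix: the load from smem_exchange[last_thread + 1] used to be unconditional, so when last_thread == kNThreads - 1 the kernel indexed one slot past the end of the shared-memory array. The new guard only issues that load when the final state really spills into the neighbouring vector and that vector exists. A small sketch of the guard, reusing the assumed kNThreads / kNElts / kWidth values from the previous snippet:

```python
# Mirrors the guard condition added above (assumed parameters, not the kernel).
kNThreads, kNElts, kWidth = 128, 8, 4


def may_read_second_vector(seqlen: int, last_thread: int) -> bool:
    offset = (seqlen - (kWidth - 1)) % kNElts
    spills_over = (offset + kWidth - 2) >= kNElts   # state crosses into the next vector
    neighbour_exists = (last_thread + 1) < kNThreads
    return spills_over and neighbour_exists


# Before the fix, the second load happened whenever the state spilled over;
# with last_thread == kNThreads - 1 that read smem_exchange[kNThreads],
# one slot past the end -- the illegal access reported on H100.
assert may_read_second_vector(seqlen=9, last_thread=0)          # spill, neighbour exists
assert not may_read_second_vector(seqlen=9, last_thread=127)    # spill, but no neighbour
```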
@@ -151,7 +151,7 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
 @pytest.mark.parametrize("has_bias", [True])
 @pytest.mark.parametrize("width", [4])
 @pytest.mark.parametrize(
-    'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
+    'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
 @pytest.mark.parametrize('dim', [64])
 @pytest.mark.parametrize('batch', [1])
 def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
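The single new value, seqlen = 1025, is one token past 1024; under the kChunkSize = 1024 assumption used in the sketches above it is the smallest sequence length whose final state is split across two chunks, so this test now exercises the new chunk-boundary path in the kernel.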
@@ -420,7 +420,10 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
 
     unpadded_out = out[:, :out_ref_tensor.shape[-1]]
     assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
-    assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
+    assert torch.allclose(final_states[state_indices],
+                          final_states_ref[state_indices],
+                          rtol=rtol,
+                          atol=atol)
 
     causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
                              padded_state_indices, has_initial_states,
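The assertion now compares only the cache rows selected by state_indices instead of the whole final_states tensor. The snippet below is a self-contained illustration (hypothetical shapes, not the real test fixtures) of why that matters when some rows of the state cache are padding and never written by the kernel:

```python
import torch

# Stand-in state cache: 4 rows, but only rows 0 and 2 belong to real sequences.
final_states = torch.zeros(4, 3)
final_states_ref = torch.zeros(4, 3)
state_indices = torch.tensor([0, 2])

final_states[state_indices] = torch.randn(2, 3)
final_states_ref[state_indices] = final_states[state_indices].clone()
final_states[3] = 123.0   # stale garbage in an unused / padded row

# Comparing only the rows the kernel owns passes ...
assert torch.allclose(final_states[state_indices], final_states_ref[state_indices])
# ... while a whole-tensor comparison would also test untouched rows and fail.
assert not torch.allclose(final_states, final_states_ref)
```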
@@ -555,7 +555,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
     device = "cuda"
     rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
     if itype == torch.bfloat16:
-        rtol, atol = 7e-2, 7e-2
+        rtol, atol = 1e-1, 1e-1
     if torch.version.hip:
         atol *= 2
     # set seed
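The bfloat16 tolerances are relaxed from 7e-2 to 1e-1; the exact values are the test authors' choice. The snippet below only illustrates the scale of bfloat16 rounding error that makes tight tolerances impractical for this dtype:

```python
import torch

# A single float32 -> bfloat16 round-trip keeps only about 2-3 significant
# decimal digits (8-bit mantissa), and the kernel accumulates many such roundings.
x = torch.rand(4096, dtype=torch.float32)
err = (x.to(torch.bfloat16).to(torch.float32) - x).abs().max().item()
print(f"max single-cast bf16 error: {err:.1e}")   # typically on the order of 1e-3
```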
@@ -610,8 +610,8 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
                                  dt_bias=dt_bias,
                                  dt_softplus=True)
 
-    print("Output diff max", (out - out_ref[0]).max())
-    print("Output diff mean", (out - out_ref[0]).mean())
+    print("Output diff max", (out[:batch_size] - out_ref).max())
+    print("Output diff mean", (out[:batch_size] - out_ref).mean())
     print("Output state diff max", (state[state_indices, :] - state_ref).max())
     print("Output state diff mean",
           (state[state_indices, :] - state_ref).mean())