mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 05:42:16 +08:00
[BugFix][Kernel]: fix illegal memory access in causal_conv1d when conv_states is None (#10928)
Signed-off-by: xffxff <1247714429@qq.com>
This commit is contained in:
parent
c889d5888b
commit
78029b34ed
@ -424,7 +424,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
|||||||
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
|
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
|
||||||
// (which occurs when `final_state_position` is a non-positive index)
|
// (which occurs when `final_state_position` is a non-positive index)
|
||||||
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
|
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
|
||||||
if (final_state_position < 0 && seqlen > kWidth){
|
if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
|
||||||
input_t vals_load[kNElts] = {0};
|
input_t vals_load[kNElts] = {0};
|
||||||
if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
|
if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
|
||||||
// chunk = n_chunks - 2, a segment of the final state sits in the last index
|
// chunk = n_chunks - 2, a segment of the final state sits in the last index
|
||||||
|
|||||||
@ -149,13 +149,14 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
|
|||||||
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
|
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
|
||||||
@pytest.mark.parametrize("silu_activation", [True])
|
@pytest.mark.parametrize("silu_activation", [True])
|
||||||
@pytest.mark.parametrize("has_bias", [True])
|
@pytest.mark.parametrize("has_bias", [True])
|
||||||
|
@pytest.mark.parametrize("has_initial_state", [True, False])
|
||||||
@pytest.mark.parametrize("width", [4])
|
@pytest.mark.parametrize("width", [4])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
|
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
|
||||||
@pytest.mark.parametrize('dim', [64])
|
@pytest.mark.parametrize('dim', [64])
|
||||||
@pytest.mark.parametrize('batch', [1])
|
@pytest.mark.parametrize('batch', [1])
|
||||||
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
|
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
|
||||||
itype):
|
has_initial_state, itype):
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||||
if itype == torch.bfloat16:
|
if itype == torch.bfloat16:
|
||||||
@ -167,11 +168,18 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
|
|||||||
|
|
||||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||||
initial_states = torch.randn(batch,
|
if has_initial_state:
|
||||||
dim,
|
initial_states = torch.randn(batch,
|
||||||
width - 1,
|
dim,
|
||||||
device=device,
|
width - 1,
|
||||||
dtype=itype)
|
device=device,
|
||||||
|
dtype=itype)
|
||||||
|
has_initial_state_tensor = torch.ones(batch,
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=x.device)
|
||||||
|
else:
|
||||||
|
initial_states = None
|
||||||
|
has_initial_state_tensor = None
|
||||||
x_ref = x.clone()
|
x_ref = x.clone()
|
||||||
weight_ref = weight.clone()
|
weight_ref = weight.clone()
|
||||||
bias_ref = bias.clone() if bias is not None else None
|
bias_ref = bias.clone() if bias is not None else None
|
||||||
@ -183,9 +191,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
|
|||||||
bias,
|
bias,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
conv_states=initial_states,
|
conv_states=initial_states,
|
||||||
has_initial_state=torch.ones(batch,
|
has_initial_state=has_initial_state_tensor)
|
||||||
dtype=torch.bool,
|
|
||||||
device=x.device))
|
|
||||||
out_ref, final_states_ref = causal_conv1d_ref(
|
out_ref, final_states_ref = causal_conv1d_ref(
|
||||||
x_ref,
|
x_ref,
|
||||||
weight_ref,
|
weight_ref,
|
||||||
@ -193,11 +199,12 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
|
|||||||
initial_states=initial_states_ref,
|
initial_states=initial_states_ref,
|
||||||
return_final_states=True,
|
return_final_states=True,
|
||||||
activation=activation)
|
activation=activation)
|
||||||
assert initial_states is not None and final_states_ref is not None
|
if has_initial_state:
|
||||||
assert torch.allclose(initial_states,
|
assert initial_states is not None and final_states_ref is not None
|
||||||
final_states_ref,
|
assert torch.allclose(initial_states,
|
||||||
rtol=rtol,
|
final_states_ref,
|
||||||
atol=atol)
|
rtol=rtol,
|
||||||
|
atol=atol)
|
||||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
causal_conv1d_opcheck_fn(x,
|
causal_conv1d_opcheck_fn(x,
|
||||||
@ -205,9 +212,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
|
|||||||
bias,
|
bias,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
conv_states=initial_states,
|
conv_states=initial_states,
|
||||||
has_initial_state=torch.ones(batch,
|
has_initial_state=has_initial_state_tensor)
|
||||||
dtype=torch.bool,
|
|
||||||
device=x.device))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user