[CI] fix mamba kernel test (#26250)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
Jiangyun Zhu 2025-10-06 02:26:59 +08:00 committed by GitHub
parent 512b8affa4
commit 9c3c21c519
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 1 deletions

View File

@ -477,6 +477,7 @@ steps:
source_file_dependencies:
- csrc/mamba/
- tests/kernels/mamba
- vllm/model_executor/layers/mamba/ops
commands:
- pytest -v -s kernels/mamba

View File

@ -165,7 +165,17 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state.detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device)
out = causal_conv1d_update(
x,
conv_state,
weight,
bias,
activation=activation,
conv_state_indices=conv_state_indices,
)
out_ref = causal_conv1d_update_ref(
x_ref, conv_state_ref, weight, bias, activation=activation
)