mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 06:14:57 +08:00
[CI] fix mamba kernel test (#26250)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
parent
512b8affa4
commit
9c3c21c519
@ -477,6 +477,7 @@ steps:
|
||||
source_file_dependencies:
|
||||
- csrc/mamba/
|
||||
- tests/kernels/mamba
|
||||
- vllm/model_executor/layers/mamba/ops
|
||||
commands:
|
||||
- pytest -v -s kernels/mamba
|
||||
|
||||
|
||||
@ -165,7 +165,17 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
conv_state_ref = conv_state.detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
|
||||
|
||||
conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device)
|
||||
|
||||
out = causal_conv1d_update(
|
||||
x,
|
||||
conv_state,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_state_indices=conv_state_indices,
|
||||
)
|
||||
out_ref = causal_conv1d_update_ref(
|
||||
x_ref, conv_state_ref, weight, bias, activation=activation
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user