mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:54:56 +08:00
[CI] fix mamba kernel test (#26250)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
parent
512b8affa4
commit
9c3c21c519
@ -477,6 +477,7 @@ steps:
|
|||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/mamba/
|
- csrc/mamba/
|
||||||
- tests/kernels/mamba
|
- tests/kernels/mamba
|
||||||
|
- vllm/model_executor/layers/mamba/ops
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s kernels/mamba
|
- pytest -v -s kernels/mamba
|
||||||
|
|
||||||
|
|||||||
@ -165,7 +165,17 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
|
|||||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||||
conv_state_ref = conv_state.detach().clone()
|
conv_state_ref = conv_state.detach().clone()
|
||||||
activation = None if not silu_activation else "silu"
|
activation = None if not silu_activation else "silu"
|
||||||
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
|
|
||||||
|
conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
out = causal_conv1d_update(
|
||||||
|
x,
|
||||||
|
conv_state,
|
||||||
|
weight,
|
||||||
|
bias,
|
||||||
|
activation=activation,
|
||||||
|
conv_state_indices=conv_state_indices,
|
||||||
|
)
|
||||||
out_ref = causal_conv1d_update_ref(
|
out_ref = causal_conv1d_update_ref(
|
||||||
x_ref, conv_state_ref, weight, bias, activation=activation
|
x_ref, conv_state_ref, weight, bias, activation=activation
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user