diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index e5a5c9dd6f712..661c884627b00 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -245,7 +245,7 @@ def _chunk_scan_fwd_kernel( ) if not HAS_INITSTATES and (seq_idx != seq_idx_prev): prev_states = tl.zeros( - (BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=C_ptr.dtype.element_ty + (BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty ) else: prev_states = tl.load(