diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py index 8fa70054f02c..ba40d42307fa 100644 --- a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -618,7 +618,9 @@ if triton.__version__ >= "2.1.0": b_ctx_len, max_input_len, alibi_slopes=None): - BLOCK = 128 + + cap = torch.cuda.get_device_capability() + BLOCK = 128 if cap[0] >= 8 else 64 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv