diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 0c6e0dfefd8a..f37a829f401c 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -947,6 +947,7 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): nums_dict = {} # type: ignore batch_ptr = None token_chunk_offset_ptr = None + device = query_start_loc_p.device for BLOCK_M in [8]: # cover all BLOCK_M values nums = -(-seqlens // BLOCK_M) nums_dict[BLOCK_M] = {} @@ -968,11 +969,11 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): batch_ptr = torch.full((MAX_NUM_PROGRAMS, ), PAD_SLOT_ID, dtype=torch.int32, - device='cuda') + device=device) token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ), PAD_SLOT_ID, dtype=torch.int32, - device='cuda') + device=device) else: if batch_ptr.nelement() < MAX_NUM_PROGRAMS: batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)