[Kernel] correct cpu worker function parameter type (#19745)

Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
Ning Xie 2025-06-20 18:50:13 +08:00 committed by GitHub
parent e384f2f108
commit 71d1219545
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 5 deletions

View File

@ -29,7 +29,7 @@ class _PagedAttention:
head_size: int, head_size: int,
*args, *args,
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size) return 2, num_blocks, block_size * num_kv_heads * head_size
@staticmethod @staticmethod
def split_kv_cache( def split_kv_cache(

View File

@ -3,7 +3,7 @@
"""A CPU worker class.""" """A CPU worker class."""
import os import os
from importlib import util from importlib import util
from typing import Dict, List, Optional, Set, Tuple, Type from typing import List, Optional, Set, Tuple, Type
import torch import torch
import torch.distributed import torch.distributed
@ -88,13 +88,13 @@ class CPUCacheEngine:
torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu")) torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
return kv_cache return kv_cache
def swap_in(self, src_to_dst: Dict[int, int]) -> None: def swap_in(self, src_to_dst: torch.Tensor) -> None:
raise NotImplementedError("Swap is not supported in CPUCacheEngine.") raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
def swap_out(self, src_to_dst: Dict[int, int]) -> None: def swap_out(self, src_to_dst: torch.Tensor) -> None:
raise NotImplementedError("Swap is not supported in CPUCacheEngine.") raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: def copy(self, src_to_dsts: torch.Tensor) -> None:
self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts) self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
@staticmethod @staticmethod