mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 10:26:11 +08:00
[Kernel] correct cpu worker function parameter type (#19745)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
parent
e384f2f108
commit
71d1219545
@ -29,7 +29,7 @@ class _PagedAttention:
|
|||||||
head_size: int,
|
head_size: int,
|
||||||
*args,
|
*args,
|
||||||
) -> Tuple[int, ...]:
|
) -> Tuple[int, ...]:
|
||||||
return (2, num_blocks, block_size * num_kv_heads * head_size)
|
return 2, num_blocks, block_size * num_kv_heads * head_size
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def split_kv_cache(
|
def split_kv_cache(
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
"""A CPU worker class."""
|
"""A CPU worker class."""
|
||||||
import os
|
import os
|
||||||
from importlib import util
|
from importlib import util
|
||||||
from typing import Dict, List, Optional, Set, Tuple, Type
|
from typing import List, Optional, Set, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
@ -88,13 +88,13 @@ class CPUCacheEngine:
|
|||||||
torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
|
torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
|
||||||
return kv_cache
|
return kv_cache
|
||||||
|
|
||||||
def swap_in(self, src_to_dst: Dict[int, int]) -> None:
|
def swap_in(self, src_to_dst: torch.Tensor) -> None:
|
||||||
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
|
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
|
||||||
|
|
||||||
def swap_out(self, src_to_dst: Dict[int, int]) -> None:
|
def swap_out(self, src_to_dst: torch.Tensor) -> None:
|
||||||
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
|
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
|
||||||
|
|
||||||
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
|
def copy(self, src_to_dsts: torch.Tensor) -> None:
|
||||||
self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
|
self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user