Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-14 08:35:01 +08:00)
[Minor] More fix of test_cache.py CI test failure (#2750)
commit fe6d09ae61
parent ed70c70ea3
test_cache.py

@@ -181,16 +181,15 @@ def test_swap_blocks(
     num_blocks: int,
     dtype: torch.dtype,
     seed: int,
-    device: int,
+    device: str,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)
-    src_device = f"{direction[0]}:{device}" if direction[
-        0] == "cuda" else direction[0]
-    dst_device = f"{direction[1]}:{device}" if direction[
-        1] == "cuda" else direction[1]
+
+    src_device = device if direction[0] == "cuda" else 'cpu'
+    dst_device = device if direction[1] == "cuda" else 'cpu'
 
     src_blocks = random.sample(range(num_blocks), num_mappings)
     # For the same device, mapping must not overlap
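For context, a minimal standalone sketch of what the rewritten device selection does, assuming the test is now parametrized with a full device string such as "cuda:0" rather than an integer index; the resolve_devices name below is hypothetical, not part of the diff:

from typing import Tuple


def resolve_devices(direction: Tuple[str, str], device: str) -> Tuple[str, str]:
    # Each endpoint of the swap is either the parametrized CUDA device
    # string or plain "cpu"; no string formatting against an int is needed.
    src_device = device if direction[0] == "cuda" else "cpu"
    dst_device = device if direction[1] == "cuda" else "cpu"
    return src_device, dst_device


assert resolve_devices(("cuda", "cpu"), "cuda:0") == ("cuda:0", "cpu")
assert resolve_devices(("cpu", "cuda"), "cuda:0") == ("cpu", "cuda:0")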
@@ -258,10 +258,13 @@ def create_kv_caches_with_random(
         key_cache = torch.empty(size=key_cache_shape,
                                 dtype=torch_dtype,
                                 device=device)
-        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
-            key_cache.uniform_(-scale, scale)
-        elif cache_dtype == 'fp8_e5m2':
+        if cache_dtype == 'fp8_e5m2':
             _generate_random_fp8_e5m2(key_cache, -scale, scale)
+        elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
+            key_cache.uniform_(-scale, scale)
+        else:
+            raise ValueError(
+                f"Does not support key cache of type {cache_dtype}")
         key_caches.append(key_cache)
 
         value_cache_shape = (num_blocks, num_heads, head_size, block_size)
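This hunk (and the matching value-cache hunk below) reorders the random-initialization dispatch: the explicit fp8_e5m2 cache-dtype string is handled first, floating-point torch dtypes fall back to uniform_, and anything else now raises instead of being silently skipped. A minimal sketch of that pattern, assuming fp8 caches are stored as raw uint8 bytes as a stand-in for vLLM's _generate_random_fp8_e5m2 helper; init_random_cache is a hypothetical name:

import torch


def init_random_cache(cache: torch.Tensor, cache_dtype: str, scale: float) -> None:
    # New branch order: explicit fp8 string first, then the torch dtype,
    # then a loud failure for anything unsupported.
    if cache_dtype == "fp8_e5m2":
        # Stand-in for _generate_random_fp8_e5m2: fill the byte-typed
        # cache with random raw bytes.
        cache.random_(0, 256)
    elif cache.dtype in (torch.half, torch.bfloat16, torch.float):
        cache.uniform_(-scale, scale)
    else:
        raise ValueError(f"Does not support cache of type {cache_dtype}")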
@@ -270,9 +273,12 @@ def create_kv_caches_with_random(
         value_cache = torch.empty(size=value_cache_shape,
                                   dtype=torch_dtype,
                                   device=device)
-        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
-            value_cache.uniform_(-scale, scale)
-        elif cache_dtype == 'fp8_e5m2':
+        if cache_dtype == 'fp8_e5m2':
             _generate_random_fp8_e5m2(value_cache, -scale, scale)
+        elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
+            value_cache.uniform_(-scale, scale)
+        else:
+            raise ValueError(
+                f"Does not support value cache of type {cache_dtype}")
         value_caches.append(value_cache)
     return key_caches, value_caches
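A quick usage sketch of the new error path, reusing the hypothetical init_random_cache above: an unsupported cache dtype now fails loudly rather than leaving the tensor uninitialized.

import torch

cache = torch.empty(16, 8, dtype=torch.float64)  # double is unsupported here
try:
    init_random_cache(cache, cache_dtype="auto", scale=0.1)
except ValueError as exc:
    print(exc)  # -> Does not support cache of type auto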