[V1][Kernel] Add triton implementation for reshape_and_cache_flash (#24503)
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
parent 527821d191
commit 100b630a60
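
This commit adds a Triton counterpart to the existing CUDA cache-write op: the benchmark hunks below gain an --implementation {cuda,triton} switch and a CUDA-graph timing mode, the kernel test is parametrized over both implementations, the new kernel lands in vllm/attention/ops/triton_reshape_and_cache_flash.py, and TritonAttentionImpl is switched over to it. The new helper takes the same arguments as ops.reshape_and_cache_flash: key/value of shape [num_tokens, num_heads, head_size], NHD-layout caches of shape [num_blocks, block_size, num_heads, head_size], an int64 slot_mapping, the kv_cache_dtype string, and float32 scale tensors. A minimal usage sketch, not taken from the diff; it assumes a CUDA or ROCm device with vLLM and Triton installed and uses made-up toy sizes:

import torch

from vllm.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash)

torch.set_default_device("cuda")
num_tokens, num_heads, head_size = 8, 4, 128
num_blocks, block_size = 16, 16

key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
value = torch.randn_like(key)
# NHD layout: [num_blocks, block_size, num_heads, head_size]
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                        dtype=torch.bfloat16)
value_cache = torch.zeros_like(key_cache)
# One cache slot per token; a negative slot marks a padding token to skip.
slot_mapping = torch.arange(num_tokens, dtype=torch.int64)
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32)

# "auto" keeps the cache in the key/value dtype; "fp8" quantizes on store.
triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
                               slot_mapping, "auto", k_scale, v_scale)
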
@@ -9,6 +9,9 @@ import torch
 from tabulate import tabulate
 
 from vllm import _custom_ops as ops
+from vllm.attention.ops.triton_reshape_and_cache_flash import (
+    triton_reshape_and_cache_flash,
+)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import (
@@ -31,6 +34,8 @@ def run_benchmark(
     kv_cache_dtype: str,
     kv_cache_layout: str,
     num_iters: int,
+    implementation: str,
+    benchmark_mode: str,
     device: str = "cuda",
 ) -> float:
     """Return latency (seconds) for given num_tokens."""
@@ -38,6 +43,14 @@ def run_benchmark(
     if kv_cache_dtype == "fp8" and head_size % 16:
         raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
 
+    if implementation not in ("cuda", "triton"):
+        raise ValueError(
+            f"Unsupported implementation: {implementation}. "
+            "Only 'cuda' and 'triton' are supported."
+        )
+    if implementation == "triton" and kv_cache_layout == "HND":
+        return float("nan")  # Triton does not support HND layout yet.
+
     current_platform.seed_everything(42)
     torch.set_default_device(device)
 
@@ -65,27 +78,49 @@ def run_benchmark(
         cache_layout=kv_cache_layout,
     )
     key_cache, value_cache = key_caches[0], value_caches[0]
+    # to free unused memory
+    del key_caches, value_caches
 
     # compute per-kernel scaling factors for fp8 conversion (if used).
     k_scale = (key.amax() / 64.0).to(torch.float32)
     v_scale = (value.amax() / 64.0).to(torch.float32)
 
+    if implementation == "cuda":
+        function_under_test = lambda: ops.reshape_and_cache_flash(
+            key,  # noqa: F821
+            value,  # noqa: F821
+            key_cache,  # noqa: F821
+            value_cache,  # noqa: F821
+            slot_mapping,  # noqa: F821
+            kv_cache_dtype,
+            k_scale,
+            v_scale,
+        )
+    else:
+        function_under_test = lambda: triton_reshape_and_cache_flash(
+            key,  # noqa: F821
+            value,  # noqa: F821
+            key_cache,  # noqa: F821
+            value_cache,  # noqa: F821
+            slot_mapping,  # noqa: F821
+            kv_cache_dtype,
+            k_scale,
+            v_scale,
+        )
+    if benchmark_mode == "cudagraph":
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            function_under_test()
+        torch.cuda.synchronize()
+        function_under_test = lambda: g.replay()
+
     def run_cuda_benchmark(n_iters: int) -> float:
         nonlocal key, value, key_cache, value_cache, slot_mapping
         torch.cuda.synchronize()
         start = time.perf_counter()
         for _ in range(n_iters):
-            ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                slot_mapping,
-                kv_cache_dtype,
-                k_scale,
-                v_scale,
-            )
-        torch.cuda.synchronize()
+            function_under_test()
+            torch.cuda.synchronize()
         end = time.perf_counter()
         return (end - start) / n_iters
 
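
With --mode cudagraph, the hunk above captures a single kernel launch into a torch.cuda.CUDAGraph and then times graph replays instead of eager launches, so the measurement mostly excludes Python and launch overhead. A condensed sketch of that capture-and-replay timing pattern, written as a toy standalone example (not part of the diff) that assumes a CUDA device:

import time

import torch


def time_with_cudagraph(fn, n_iters: int = 100) -> float:
    """Capture fn() once into a CUDA graph, then time graph replays."""
    # Warm up once outside the graph, then capture a single invocation.
    fn()
    torch.cuda.synchronize()
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        fn()
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(n_iters):
        g.replay()
        torch.cuda.synchronize()
    end = time.perf_counter()
    return (end - start) / n_iters  # seconds per replayed launch


# Stand-in workload: a small elementwise op writing into a preallocated buffer.
x = torch.randn(1 << 20, device="cuda")
y = torch.empty_like(x)
print(f"{time_with_cudagraph(lambda: torch.add(x, 1.0, out=y)) * 1e6:.3f} us")
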
@@ -116,10 +151,16 @@ def main(args):
                 kv_cache_dtype=args.kv_cache_dtype,
                 kv_cache_layout=layout,
                 num_iters=args.iters,
+                implementation=args.implementation,
+                benchmark_mode=args.mode,
                 device="cuda",
             )
             rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])
 
+    print(
+        f"Benchmark results for implementation {args.implementation}"
+        f" (measuring with {args.mode}):"
+    )
     print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))
 
 
@@ -151,6 +192,21 @@ if __name__ == "__main__":
     )
 
     parser.add_argument("--iters", type=int, default=100)
+
+    parser.add_argument(
+        "--implementation",
+        type=str,
+        choices=["cuda", "triton"],
+        default="cuda",
+    )
+
+    parser.add_argument(
+        "--mode",
+        type=str,
+        choices=["cudagraph", "no_graph"],
+        default="cudagraph",
+    )
+
     args = parser.parse_args()
 
     main(args)
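
The two new arguments let the benchmark target either implementation and either timing mode, e.g. --implementation triton --mode cudagraph. A toy parser mirroring just those two add_argument calls (the real script defines several more options):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--implementation", type=str,
                    choices=["cuda", "triton"], default="cuda")
parser.add_argument("--mode", type=str,
                    choices=["cudagraph", "no_graph"], default="cudagraph")

# Benchmark the Triton kernel with eager launches instead of graph replay.
args = parser.parse_args(["--implementation", "triton", "--mode", "no_graph"])
print(args.implementation, args.mode)  # -> triton no_graph
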
@@ -39,6 +39,8 @@ CUDA_DEVICES = [
 # We assume fp8 is always enabled for testing.
 KV_CACHE_DTYPE = ["auto", "fp8"]
 
+RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]
+
 
 @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
 @pytest.mark.parametrize("num_layers", NUM_LAYERS)
@@ -223,6 +225,7 @@ def test_reshape_and_cache(
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
 @pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
+@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
 @torch.inference_mode()
 def test_reshape_and_cache_flash(
     kv_cache_factory_flashinfer,
@@ -236,9 +239,13 @@ def test_reshape_and_cache_flash(
     device: str,
     kv_cache_dtype: str,
     kv_cache_layout: str,
+    implementation: str,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    assert implementation in ["cuda", "triton"]
+    if implementation == "triton" and kv_cache_layout == "HND":
+        pytest.skip("Triton implementation only supports NHD layout.")
 
     # fp8 conversion requires continugous memory buffer. Reduce the number of
     # blocks and tokens to consume less memory.
@@ -298,12 +305,20 @@ def test_reshape_and_cache_flash(
     cloned_key_cache = key_cache_compact.clone()
     cloned_value_cache = value_cache_compact.clone()
     # Call the reshape_and_cache kernel.
-    opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
-            (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
-             k_scale, v_scale),
-            cond=(head_size == HEAD_SIZES[0]))
-    ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
-                                slot_mapping, kv_cache_dtype, k_scale, v_scale)
+    if implementation == "cuda":
+        opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
+                (key, value, key_cache, value_cache, slot_mapping,
+                 kv_cache_dtype, k_scale, v_scale),
+                cond=(head_size == HEAD_SIZES[0]))
+        ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
+                                    slot_mapping, kv_cache_dtype, k_scale,
+                                    v_scale)
+    elif implementation == "triton":
+        from vllm.attention.ops.triton_reshape_and_cache_flash import (
+            triton_reshape_and_cache_flash)
+        triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
+                                       slot_mapping, kv_cache_dtype, k_scale,
+                                       v_scale)
     key_cache_compact = permute_and_compact(key_cache)
     value_cache_compact = permute_and_compact(value_cache)
 
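
The parametrized test above runs both implementations through the same reference comparison. The equivalence can also be spot-checked directly by writing the same tokens through the CUDA op and the Triton kernel into separate caches and comparing them; a rough standalone sketch (my own example rather than repository code; assumes a CUDA device with vLLM built, NHD layout, and kv_cache_dtype="auto"):

import torch

from vllm import _custom_ops as ops
from vllm.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash)

torch.set_default_device("cuda")
num_tokens, num_heads, head_size = 16, 8, 128
num_blocks, block_size = 32, 16

key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16)
value = torch.randn_like(key)
# Scatter the tokens into random, non-overlapping cache slots.
slot_mapping = torch.randperm(num_blocks * block_size)[:num_tokens].to(torch.int64)
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)


def empty_caches():
    # NHD layout: [num_blocks, block_size, num_heads, head_size]
    kc = torch.zeros(num_blocks, block_size, num_heads, head_size,
                     dtype=torch.float16)
    return kc, torch.zeros_like(kc)


key_cache_ref, value_cache_ref = empty_caches()
ops.reshape_and_cache_flash(key, value, key_cache_ref, value_cache_ref,
                            slot_mapping, "auto", k_scale, v_scale)

key_cache_tri, value_cache_tri = empty_caches()
triton_reshape_and_cache_flash(key, value, key_cache_tri, value_cache_tri,
                               slot_mapping, "auto", k_scale, v_scale)

torch.testing.assert_close(key_cache_tri, key_cache_ref)
torch.testing.assert_close(value_cache_tri, value_cache_ref)
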
vllm/attention/ops/triton_reshape_and_cache_flash.py (new file, 176 lines)
@@ -0,0 +1,176 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+import triton
+import triton.language as tl
+
+from vllm.platforms import current_platform
+
+
+@triton.jit
+def reshape_and_cache_kernel_flash(
+    key_ptr,  # [num_tokens, num_heads, head_size]
+    value_ptr,  # [num_tokens, num_heads, head_size]
+    key_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
+    value_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
+    slot_mapping_ptr,  # [num_tokens]
+    k_scale,  # float32
+    v_scale,  # float32
+    # strides
+    key_stride: tl.int64,
+    value_stride: tl.int64,
+    block_stride: tl.int64,
+    page_stride: tl.int64,
+    num_heads: tl.constexpr,
+    head_size: tl.constexpr,
+    block_size: tl.constexpr,
+    # FP8 flags
+    FP8_KV_CACHE: tl.constexpr,
+    # tune parameters
+    TILE_SIZE: tl.constexpr,
+):
+
+    token_idx = tl.program_id(axis=0)
+    slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
+    if slot_idx < 0:
+        # Padding token that should be ignored.
+        return
+
+    tile_i = tl.program_id(axis=1)
+    tile_offs = tl.arange(0, TILE_SIZE)
+    tile_pos = tile_i * TILE_SIZE + tile_offs
+
+    block_idx = slot_idx // block_size
+    block_offset = slot_idx % block_size
+
+    src_key_idx = token_idx * key_stride
+    src_value_idx = token_idx * value_stride
+
+    tgt_idx = block_idx * block_stride + block_offset * page_stride
+
+    # [TILE_SIZE]
+    key_load = tl.load(key_ptr + src_key_idx + tile_pos,
+                       mask=tile_pos < (num_heads * head_size))
+    if FP8_KV_CACHE:
+        if key_load.dtype.is_fp8():
+            key_tile = key_load
+        else:
+            # tl.store will do the correct implicit cast to fp8,
+            # based on the key_cache_ptr.dtype.element_ty
+            key_tile = key_load / tl.load(k_scale)
+    else:
+        key_tile = key_load
+
+    # [TILE_SIZE]
+    value_load = tl.load(value_ptr + src_value_idx + tile_pos,
+                         mask=tile_pos < (num_heads * head_size))
+    if FP8_KV_CACHE:
+        if value_load.dtype.is_fp8():
+            value_tile = value_load
+        else:
+            # tl.store will do the correct implicit cast to fp8,
+            # based on the value_cache_ptr.dtype.element_ty
+            value_tile = value_load / tl.load(v_scale)
+    else:
+        value_tile = value_load
+
+    tl.store(
+        key_cache_ptr + tgt_idx + tile_pos,
+        key_tile,
+        mask=tile_pos < (num_heads * head_size),
+    )
+    tl.store(
+        value_cache_ptr + tgt_idx + tile_pos,
+        value_tile,
+        mask=tile_pos < (num_heads * head_size),
+    )
+    return
+
+
+def triton_reshape_and_cache_flash(
+    key: torch.Tensor,  # [num_tokens, num_heads, head_size]
+    value: torch.Tensor,  # [num_tokens, num_heads, head_size]
+    # [num_blocks, block_size, num_heads, head_size]
+    key_cache: torch.Tensor,
+    # [num_blocks, block_size, num_heads, head_size]
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,  # [num_tokens]
+    kv_cache_dtype: str,  # "auto", "fp8"
+    k_scale: torch.Tensor,  # float32
+    v_scale: torch.Tensor,  # float32
+):
+    num_tokens = key.shape[0]
+    num_heads = key.shape[1]
+    head_size = key.shape[2]
+    block_size = key_cache.shape[1]
+    n = num_heads * head_size
+
+    key_stride = key.stride()[0]
+    value_stride = value.stride()[0]
+    block_stride = key_cache.stride()[0]
+    page_stride = key_cache.stride()[1]
+
+    head_stride = key_cache.stride()[2]
+    assert head_stride == head_size, "only continous heads are supported"
+
+    assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), \
+        f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
+    kv_cache_torch_dtype = current_platform.fp8_dtype() if \
+        kv_cache_dtype.startswith("fp8") else key_cache.dtype
+
+    if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith(
+            "fp8"):
+        # to avoid erounous implicit cast in triton kernel (tl.store to uint8)
+        # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
+        key_cache = key_cache.view(kv_cache_torch_dtype)
+        value_cache = value_cache.view(kv_cache_torch_dtype)
+    assert kv_cache_dtype != torch.uint8, "explicit fp8 cast and store to "\
+        "uint8 is not supported by triton reshape_and_cache_flash"
+
+    FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
+    assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
+        torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8,
+        torch.float8_e4m3fnuz], \
+        "unsupported dtype of KV cache tensor, got "\
+        "{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " \
+        "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
+
+    # heuristics instead of autotuning
+    TILE_SIZE = min(2048, triton.next_power_of_2(n))
+    if torch.version.hip:
+        num_stages = 4
+        num_warps = 8
+    else:  # cuda
+        num_stages = 10
+        num_warps = 16
+        if torch.cuda.get_device_capability(key.device)[0] < 9:
+            TILE_SIZE = min(512, TILE_SIZE)
+
+    # TODO(ngl): maybe replace with static launch grid to avoid overhead if
+    # using cudagraphs
+    grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))
+
+    reshape_and_cache_kernel_flash[grid](
+        key_ptr=key,
+        value_ptr=value,
+        key_cache_ptr=key_cache,
+        value_cache_ptr=value_cache,
+        slot_mapping_ptr=slot_mapping,
+        k_scale=k_scale,
+        v_scale=v_scale,
+        # strides
+        key_stride=key_stride,
+        value_stride=value_stride,
+        block_stride=block_stride,
+        page_stride=page_stride,
+        num_heads=num_heads,
+        head_size=head_size,
+        block_size=block_size,
+        # FP8 flags
+        FP8_KV_CACHE=FP8_KV_CACHE,
+        # autotune parameters
+        TILE_SIZE=TILE_SIZE,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
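
The launch configuration in the new file is a fixed heuristic rather than an autotuned one: axis 0 of the grid gets one program per token, and the flattened num_heads * head_size elements of each token are covered by tiles of at most 2048 elements (capped at 512 on NVIDIA GPUs older than SM 9.0). A quick worked example of that grid arithmetic, using toy shapes of my own choosing:

import math


def launch_grid(num_tokens: int, num_heads: int, head_size: int,
                is_hip: bool = False, sm_major: int = 9):
    """Mirrors the TILE_SIZE / grid heuristic of the Triton kernel above."""
    n = num_heads * head_size
    tile_size = min(2048, 1 << (n - 1).bit_length())  # next power of two
    if not is_hip and sm_major < 9:
        tile_size = min(512, tile_size)
    # (one program per token, tiles needed to cover one token's K or V row)
    return (num_tokens, math.ceil(n / tile_size)), tile_size


# 8 KV heads with head_size 128 -> n = 1024, a single tile per token on SM 9.0
print(launch_grid(num_tokens=256, num_heads=8, head_size=128))
# -> ((256, 1), 1024)

# On SM 8.0 the tile is capped at 512, so each token needs two tiles.
print(launch_grid(num_tokens=256, num_heads=8, head_size=128, sm_major=8))
# -> ((256, 2), 512)
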
@@ -8,6 +8,8 @@ import torch
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
+from vllm.attention.ops.triton_reshape_and_cache_flash import (
+    triton_reshape_and_cache_flash)
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -291,7 +293,13 @@ class TritonAttentionImpl(AttentionImpl):
         if self.kv_sharing_target_layer_name is None:
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
-            ops.reshape_and_cache_flash(
+            if self.kv_cache_dtype.startswith("fp8"):
+                key_cache = key_cache.view(self.fp8_dtype)
+                value_cache = value_cache.view(self.fp8_dtype)
+                # triton kernel does not support uint8 kv_cache
+                # (because some explicit casts (e.g. float8_e4m3fnuz)
+                # are not supported)
+            triton_reshape_and_cache_flash(
                 key,
                 value,
                 key_cache,
@@ -303,8 +311,9 @@ class TritonAttentionImpl(AttentionImpl):
             )
 
         if self.kv_cache_dtype.startswith("fp8"):
-            key_cache = key_cache.view(self.fp8_dtype)
-            value_cache = value_cache.view(self.fp8_dtype)
+            if key_cache.dtype != self.fp8_dtype:
+                key_cache = key_cache.view(self.fp8_dtype)
+                value_cache = value_cache.view(self.fp8_dtype)
         num_tokens, num_heads, head_size = query.shape
         assert layer._q_scale_float == 1.0, \
             "A non 1.0 q_scale is not currently supported."
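
The backend hunks above lean on torch.Tensor.view(dtype) to reinterpret the uint8-allocated KV-cache buffers as the platform's fp8 dtype in place, so the Triton kernel can store fp8 values without an explicit cast to uint8; the added dtype check afterwards just avoids re-viewing a cache that is already fp8. A small, generic illustration of that reinterpretation (plain PyTorch, not vLLM-specific):

import torch

# A byte buffer standing in for a KV-cache block that was allocated as uint8.
raw = torch.zeros(4, 16, dtype=torch.uint8)

# view(dtype) reinterprets the same storage: uint8 and fp8 are both one byte
# per element, so the shape is unchanged and nothing is copied.
as_fp8 = raw.view(torch.float8_e4m3fn)
assert as_fp8.data_ptr() == raw.data_ptr()
assert as_fp8.shape == raw.shape
assert as_fp8.dtype == torch.float8_e4m3fn

# Re-viewing an already-fp8 tensor would just be a redundant reinterpretation;
# the guard added in the hunk above skips it when the cache is already fp8.
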