mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 23:55:44 +08:00
[Model][AMD] ROCm support for 256 head dims for Gemma (#3972)
This commit is contained in:
parent
bd3c144e0b
commit
8b317c6dd0
@ -677,8 +677,7 @@ def check_args(
|
|||||||
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
|
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
|
||||||
# TODO: Change assert if we support qkl f8 and v f16
|
# TODO: Change assert if we support qkl f8 and v f16
|
||||||
assert q.dtype == k.dtype and q.dtype == v.dtype
|
assert q.dtype == k.dtype and q.dtype == v.dtype
|
||||||
# TODO: Fix assert to check head size <=256 once supported
|
assert head_size <= 256
|
||||||
assert head_size <= 128
|
|
||||||
assert o.shape == q.shape
|
assert o.shape == q.shape
|
||||||
assert (nheads_q % nheads_k) == 0
|
assert (nheads_q % nheads_k) == 0
|
||||||
|
|
||||||
@ -729,7 +728,7 @@ class _attention(torch.autograd.Function):
|
|||||||
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
|
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
|
||||||
|
|
||||||
# Get closest power of 2 over or equal to 32.
|
# Get closest power of 2 over or equal to 32.
|
||||||
unpadded_head_dims = {32, 64, 128}
|
unpadded_head_dims = {32, 64, 128, 256}
|
||||||
if head_size not in unpadded_head_dims:
|
if head_size not in unpadded_head_dims:
|
||||||
padded_d_model = None
|
padded_d_model = None
|
||||||
for i in unpadded_head_dims:
|
for i in unpadded_head_dims:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user