mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-07 03:39:11 +08:00
[ROCm][Bugfix] Fix RuntimeError in MMEncoderAttention by replacing .view() with .reshape() (#31203)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
parent
f790068600
commit
bfa2c0bbb9
@ -19,7 +19,7 @@ def pytest_collection_modifyitems(config, items):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
|
# Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
|
||||||
# accuracy issues
|
# accuracy issues: https://github.com/vllm-project/vllm/issues/30167
|
||||||
# TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
|
# TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
|
||||||
torch.backends.cuda.enable_flash_sdp(False)
|
torch.backends.cuda.enable_flash_sdp(False)
|
||||||
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||||
|
|||||||
@ -136,7 +136,7 @@ class MMEncoderAttention(CustomOp):
|
|||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
)
|
)
|
||||||
if is_reshaped:
|
if is_reshaped:
|
||||||
output = output.view(bsz, q_len, -1)
|
output = output.reshape(bsz, q_len, -1)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _forward_fa(
|
def _forward_fa(
|
||||||
@ -174,7 +174,7 @@ class MMEncoderAttention(CustomOp):
|
|||||||
fa_version=self._fa_version,
|
fa_version=self._fa_version,
|
||||||
)
|
)
|
||||||
if is_reshaped:
|
if is_reshaped:
|
||||||
output = output.view(bsz, q_len, -1)
|
output = output.reshape(bsz, q_len, -1)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user