[ROCm][Bugfix] Fix RuntimeError in MMEncoderAttention by replacing .view() with .reshape() (#31203)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas 2025-12-23 15:48:01 -06:00 committed by GitHub
parent f790068600
commit bfa2c0bbb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 3 deletions

View File

@ -19,7 +19,7 @@ def pytest_collection_modifyitems(config, items):
return
# Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
# accuracy issues
# accuracy issues: https://github.com/vllm-project/vllm/issues/30167
# TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)

View File

@ -136,7 +136,7 @@ class MMEncoderAttention(CustomOp):
cu_seqlens=cu_seqlens,
)
if is_reshaped:
output = output.view(bsz, q_len, -1)
output = output.reshape(bsz, q_len, -1)
return output
def _forward_fa(
@ -174,7 +174,7 @@ class MMEncoderAttention(CustomOp):
fa_version=self._fa_version,
)
if is_reshaped:
output = output.view(bsz, q_len, -1)
output = output.reshape(bsz, q_len, -1)
return output
def forward_native(