[BUG FIX][NON-CUDA]quick fix to avoid call cudagraph_unsafe in attention (#25298)

Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
This commit is contained in:
Chendi.Xue 2025-09-19 23:41:23 -05:00 committed by GitHub
parent b7f186bbb3
commit 6c5f82e5aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -29,6 +29,10 @@ from vllm.utils import GiB_bytes, direct_register_custom_op
logger = init_logger(__name__)
USE_XFORMERS_OPS = None
try:
tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
except AttributeError:
tag_cudagraph_unsafe = () # type: ignore[assignment]
def check_xformers_availability():
@ -577,7 +581,7 @@ direct_register_custom_op(
mutates_args=[],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch._C.Tag.cudagraph_unsafe, ),
tags=tag_cudagraph_unsafe,
)
@ -628,5 +632,5 @@ direct_register_custom_op(
mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch._C.Tag.cudagraph_unsafe, ),
tags=tag_cudagraph_unsafe,
)