set assume_32bit_indexing and pass unbacked hints (#30459)

Signed-off-by: Laith Sakka <lsakka@meta.com>
Author: Laith Sakka, 2025-12-13 18:36:53 +03:00 (committed by GitHub)
Parent: 39cefbdf17
Commit: 763963aa73


@@ -28,7 +28,7 @@ from vllm.config.compilation import DynamicShapesType
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import resolve_obj_by_qualname
-from vllm.utils.torch_utils import supports_dynamo
+from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo
 from .monitor import start_monitoring_torch_compile
@@ -316,7 +316,13 @@ def _support_torch_compile(
     def _mark_dynamic_inputs(mod, type, *args, **kwargs):
         def mark_dynamic(arg, dims):
             if type == DynamicShapesType.UNBACKED:
-                torch._dynamo.decorators.mark_unbacked(arg, dims)
+                if is_torch_equal_or_newer("2.10.0.dev"):
+                    for dim in dims:
+                        torch._dynamo.decorators.mark_unbacked(
+                            arg, dim, hint_override=arg.size()[dim]
+                        )
+                else:
+                    torch._dynamo.decorators.mark_unbacked(arg, dims)
             else:
                 torch._dynamo.mark_dynamic(arg, dims)
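
For context, the sketch below (not part of this diff) shows how the per-dimension hint is meant to be used at call time: the dimension is still traced as an unbacked symbol, but the current runtime size is handed to the compiler as a hint. The forward function and shapes are hypothetical, and the hint_override keyword is assumed to require a PyTorch nightly at or above 2.10.0.dev, matching the version gate in the hunk above.

# Illustrative sketch only; the model function and shapes are made up.
import torch

def forward(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1

x = torch.randn(8, 16)

# Treat dim 0 as an unbacked symbol, but pass the current runtime size as a
# hint (assumed to need torch >= 2.10.0.dev, per the version gate above).
torch._dynamo.decorators.mark_unbacked(x, 0, hint_override=x.size(0))

compiled = torch.compile(forward)
out = compiled(x)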
@@ -350,7 +356,13 @@ def _support_torch_compile(
             if isinstance(arg, torch.Tensor):
                 # In case dims is specified with negative indexing
                 dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
-                torch._dynamo.decorators.mark_unbacked(arg, dims)
+                if is_torch_equal_or_newer("2.10.0.dev"):
+                    for dim in dims:
+                        torch._dynamo.decorators.mark_unbacked(
+                            arg, dim, hint_override=arg.size()[dim]
+                        )
+                else:
+                    torch._dynamo.decorators.mark_unbacked(arg, dims)

     def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
@@ -488,6 +500,12 @@ def _support_torch_compile(
         if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
             fx_config_patches["backed_size_oblivious"] = True
+        # Prepare inductor config patches
+        # assume_32bit_indexing is only available in torch 2.10.0.dev+
+        inductor_config_patches = {}
+        if is_torch_equal_or_newer("2.10.0.dev"):
+            inductor_config_patches["assume_32bit_indexing"] = True
         with (
             patch.object(
                 InliningInstructionTranslator, "inline_call_", patched_inline_call
@@ -496,6 +514,7 @@ def _support_torch_compile(
             maybe_use_cudagraph_partition_wrapper(self.vllm_config),
             torch.fx.experimental._config.patch(**fx_config_patches),
             _torch27_patch_tensor_subclasses(),
+            torch._inductor.config.patch(**inductor_config_patches),
         ):
             if envs.VLLM_USE_AOT_COMPILE:
                 self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
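
To make the Inductor side of the change concrete, here is a minimal sketch (not vLLM code) of scoping the setting with torch._inductor.config.patch. is_torch_equal_or_newer is the vLLM helper imported in the first hunk, and assume_32bit_indexing is taken from the diff above, which gates it on torch >= 2.10.0.dev; the model function is hypothetical.

# Sketch, assuming a PyTorch build new enough to expose assume_32bit_indexing.
import torch
from vllm.utils.torch_utils import is_torch_equal_or_newer

def forward(x: torch.Tensor) -> torch.Tensor:  # hypothetical model fn
    return x.sum(dim=-1)

inductor_config_patches = {}
if is_torch_equal_or_newer("2.10.0.dev"):
    # Flag name from the diff; it suggests Inductor may assume indices fit in 32 bits.
    inductor_config_patches["assume_32bit_indexing"] = True

# config.patch is a context manager: the override applies only to compilations
# triggered inside the block and is reverted on exit.
with torch._inductor.config.patch(**inductor_config_patches):
    compiled = torch.compile(forward)
    out = compiled(torch.randn(4, 8))

Keeping the patch inside the with block mirrors the diff, which only enables the flag for the compilation it wraps, so other compilations in the same process keep the default indexing assumptions.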