mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 23:05:59 +08:00
set assume_32bit_indexing and pass unbacked hints (#30459)
Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
parent
39cefbdf17
commit
763963aa73
@ -28,7 +28,7 @@ from vllm.config.compilation import DynamicShapesType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.torch_utils import supports_dynamo
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo
|
||||
|
||||
from .monitor import start_monitoring_torch_compile
|
||||
|
||||
@ -316,7 +316,13 @@ def _support_torch_compile(
|
||||
def _mark_dynamic_inputs(mod, type, *args, **kwargs):
|
||||
def mark_dynamic(arg, dims):
|
||||
if type == DynamicShapesType.UNBACKED:
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
if is_torch_equal_or_newer("2.10.0.dev"):
|
||||
for dim in dims:
|
||||
torch._dynamo.decorators.mark_unbacked(
|
||||
arg, dim, hint_override=arg.size()[dim]
|
||||
)
|
||||
else:
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
else:
|
||||
torch._dynamo.mark_dynamic(arg, dims)
|
||||
|
||||
@ -350,7 +356,13 @@ def _support_torch_compile(
|
||||
if isinstance(arg, torch.Tensor):
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
if is_torch_equal_or_newer("2.10.0.dev"):
|
||||
for dim in dims:
|
||||
torch._dynamo.decorators.mark_unbacked(
|
||||
arg, dim, hint_override=arg.size()[dim]
|
||||
)
|
||||
else:
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# torch.compiler.is_compiling() means we are inside the compilation
|
||||
@ -488,6 +500,12 @@ def _support_torch_compile(
|
||||
if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
|
||||
fx_config_patches["backed_size_oblivious"] = True
|
||||
|
||||
# Prepare inductor config patches
|
||||
# assume_32bit_indexing is only available in torch 2.10.0.dev+
|
||||
inductor_config_patches = {}
|
||||
if is_torch_equal_or_newer("2.10.0.dev"):
|
||||
inductor_config_patches["assume_32bit_indexing"] = True
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
InliningInstructionTranslator, "inline_call_", patched_inline_call
|
||||
@ -496,6 +514,7 @@ def _support_torch_compile(
|
||||
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
|
||||
torch.fx.experimental._config.patch(**fx_config_patches),
|
||||
_torch27_patch_tensor_subclasses(),
|
||||
torch._inductor.config.patch(**inductor_config_patches),
|
||||
):
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user