From 0585a9e73c072a8cbb1a64bea3c26dd0d2dde402 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Wed, 27 Aug 2025 20:03:05 +0100
Subject: [PATCH] Disable `torch.compile` for dynamic rope models in
 Transformers backend (#23738)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 vllm/model_executor/models/transformers.py | 25 +++++++++++++++++++---
 1 file changed, 22 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index fc242d1adafd0..dffc347a73668 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -88,6 +88,23 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
     logger.debug("%s: %s -> %s", name, old_module, new_module)
 
 
+def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
+    """
+    Callable to be passed to `@support_torch_compile`'s `enable_if` argument.
+
+    Defaults to `True` but is disabled in the following situations:
+
+    - The model uses dynamic rope scaling.
+    """
+    enable = True
+    text_config = vllm_config.model_config.hf_config.get_text_config()
+    # Dynamic rope scaling is not compatible with torch.compile
+    rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
+    if rope_scaling.get("rope_type") == "dynamic":
+        enable = False
+    return enable
+
+
 def replace_linear_class(
         linear: nn.Linear, style: Literal["colwise", "rowwise"],
         quant_config: QuantizationConfig
@@ -641,7 +658,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
-@support_torch_compile
+@support_torch_compile(enable_if=can_enable_torch_compile)
 class TransformersModel(TransformersBase):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -653,7 +670,7 @@ class TransformersModel(TransformersBase):
         })
 
 
-@support_torch_compile
+@support_torch_compile(enable_if=can_enable_torch_compile)
 class TransformersForCausalLM(TransformersBase):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -709,12 +726,14 @@ def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
     info=MultiModalProcessingInfo,
     dummy_inputs=MultiModalDummyInputsBuilder)
 @support_torch_compile(
+    # set `positions` to last dim to support Qwen-mrope
     dynamic_arg_dims={
         "input_ids": 0,
         "positions": -1,
         "intermediate_tensors": 0,
         "inputs_embeds": 0,
-    })  # set `positions` to last dim to support Qwen-mrope
+    },
+    enable_if=can_enable_torch_compile)
 class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
     # Backwards compatibility for prev released models. State dicts back then
     # had different formats and cannot be loaded with `AutoModel` mapping as is
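
For review context, here is a minimal, self-contained sketch of the gating behaviour this patch introduces: `torch.compile` stays enabled unless the model's text config declares `rope_scaling` with `rope_type == "dynamic"`. The `SimpleNamespace` objects below are hypothetical stand-ins for the result of `hf_config.get_text_config()`, used only to exercise the same check outside vLLM; they are not part of the patch or of the vLLM API.

# Hypothetical stand-alone sketch of the check performed by
# can_enable_torch_compile(); SimpleNamespace stands in for the HF text
# config returned by hf_config.get_text_config().
from types import SimpleNamespace


def compile_enabled(text_config) -> bool:
    # Dynamic rope scaling is not compatible with torch.compile
    rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
    return rope_scaling.get("rope_type") != "dynamic"


# A model using dynamic NTK rope scaling -> torch.compile is disabled.
dynamic_cfg = SimpleNamespace(rope_scaling={"rope_type": "dynamic", "factor": 2.0})
assert compile_enabled(dynamic_cfg) is False

# No rope scaling, or a static variant such as "linear" -> compile stays on.
plain_cfg = SimpleNamespace(rope_scaling=None)
linear_cfg = SimpleNamespace(rope_scaling={"rope_type": "linear", "factor": 4.0})
assert compile_enabled(plain_cfg) is True
assert compile_enabled(linear_cfg) is True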