mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 12:04:27 +08:00
Disable torch.compile for dynamic rope models in Transformers backend (#23738)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
3c0ef769ba
commit
0585a9e73c
@ -88,6 +88,23 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
|
|||||||
logger.debug("%s: %s -> %s", name, old_module, new_module)
|
logger.debug("%s: %s -> %s", name, old_module, new_module)
|
||||||
|
|
||||||
|
|
||||||
|
def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
    """
    Callable to be passed to `@support_torch_compile`'s `enable_if` argument.

    Defaults to `True` but is disabled in the following situations:

    - The model uses dynamic rope scaling.

    Args:
        vllm_config: Config whose `model_config.hf_config` exposes
            `get_text_config()`.

    Returns:
        `False` when the text config's `rope_scaling` selects the
        `"dynamic"` rope type, `True` otherwise (including when
        `rope_scaling` is missing, `None` or empty).
    """
    text_config = vllm_config.model_config.hf_config.get_text_config()
    # `rope_scaling` may be absent or explicitly `None`; treat both as "no
    # scaling configured".
    rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
    # Older Hugging Face configs spell the key `type` instead of `rope_type`,
    # so fall back to it when the modern key is not present.
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
    # Dynamic rope scaling is not compatible with torch.compile
    return rope_type != "dynamic"
|
||||||
|
|
||||||
|
|
||||||
def replace_linear_class(
|
def replace_linear_class(
|
||||||
linear: nn.Linear, style: Literal["colwise", "rowwise"],
|
linear: nn.Linear, style: Literal["colwise", "rowwise"],
|
||||||
quant_config: QuantizationConfig
|
quant_config: QuantizationConfig
|
||||||
@ -641,7 +658,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
|||||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||||
class TransformersModel(TransformersBase):
|
class TransformersModel(TransformersBase):
|
||||||
hf_to_vllm_mapper = WeightsMapper(
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
orig_to_new_prefix={
|
orig_to_new_prefix={
|
||||||
@ -653,7 +670,7 @@ class TransformersModel(TransformersBase):
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||||
class TransformersForCausalLM(TransformersBase):
|
class TransformersForCausalLM(TransformersBase):
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
@ -709,12 +726,14 @@ def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
|
|||||||
info=MultiModalProcessingInfo,
|
info=MultiModalProcessingInfo,
|
||||||
dummy_inputs=MultiModalDummyInputsBuilder)
|
dummy_inputs=MultiModalDummyInputsBuilder)
|
||||||
@support_torch_compile(
|
@support_torch_compile(
|
||||||
|
# set `positions` to last dim to support Qwen-mrope
|
||||||
dynamic_arg_dims={
|
dynamic_arg_dims={
|
||||||
"input_ids": 0,
|
"input_ids": 0,
|
||||||
"positions": -1,
|
"positions": -1,
|
||||||
"intermediate_tensors": 0,
|
"intermediate_tensors": 0,
|
||||||
"inputs_embeds": 0,
|
"inputs_embeds": 0,
|
||||||
}) # set `positions` to last dim to support Qwen-mrope
|
},
|
||||||
|
enable_if=can_enable_torch_compile)
|
||||||
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
|
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
|
||||||
# Backwards compatibility for prev released models. State dicts back then
|
# Backwards compatibility for prev released models. State dicts back then
|
||||||
# had different formats and cannot be loaded with `AutoModel` mapping as is
|
# had different formats and cannot be loaded with `AutoModel` mapping as is
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user