Fix the torch version parsing logic (#15857)

This commit is contained in:
Lu Fang 2025-04-10 07:37:47 -07:00 committed by GitHub
parent 8661c0241d
commit 7678fcd5b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 26 additions and 11 deletions

View File

@ -2,7 +2,6 @@
import contextlib import contextlib
import copy import copy
import hashlib import hashlib
import importlib.metadata
import os import os
from contextlib import ExitStack from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
@ -11,9 +10,9 @@ from unittest.mock import patch
import torch import torch
import torch._inductor.compile_fx import torch._inductor.compile_fx
import torch.fx as fx import torch.fx as fx
from packaging.version import Version
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils import is_torch_equal_or_newer
class CompilerInterface: class CompilerInterface:
@ -379,7 +378,7 @@ class InductorAdaptor(CompilerInterface):
manually setting up internal contexts. But we also rely on non-public manually setting up internal contexts. But we also rely on non-public
APIs which might not provide these guarantees. APIs which might not provide these guarantees.
""" """
if Version(importlib.metadata.version('torch')) >= Version("2.6"): if is_torch_equal_or_newer("2.6"):
import torch._dynamo.utils import torch._dynamo.utils
return torch._dynamo.utils.get_metrics_context() return torch._dynamo.utils.get_metrics_context()
else: else:

View File

@ -1,17 +1,17 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import hashlib import hashlib
import importlib.metadata
import inspect import inspect
import json import json
import types import types
from typing import Any, Callable, Dict, Optional, Union from typing import Any, Callable, Dict, Optional, Union
import torch import torch
from packaging.version import Version
from torch import fx from torch import fx
if Version(importlib.metadata.version('torch')) >= Version("2.6"): from vllm.utils import is_torch_equal_or_newer
if is_torch_equal_or_newer("2.6"):
from torch._inductor.custom_graph_pass import CustomGraphPass from torch._inductor.custom_graph_pass import CustomGraphPass
else: else:
# CustomGraphPass is not present in 2.5 or lower, import our version # CustomGraphPass is not present in 2.5 or lower, import our version

View File

@ -4,7 +4,6 @@ import ast
import copy import copy
import enum import enum
import hashlib import hashlib
import importlib.metadata
import json import json
import sys import sys
import warnings import warnings
@ -18,7 +17,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, Union) Optional, Protocol, Union)
import torch import torch
from packaging.version import Version
from pydantic import BaseModel, Field, PrivateAttr from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig from transformers import PretrainedConfig
@ -40,8 +38,8 @@ from vllm.transformers_utils.config import (
from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, get_open_port, random_uuid, get_cpu_memory, get_open_port, is_torch_equal_or_newer,
resolve_obj_by_qualname) random_uuid, resolve_obj_by_qualname)
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
@ -3285,7 +3283,7 @@ class CompilationConfig(BaseModel):
# and it is not yet a priority. RFC here: # and it is not yet a priority. RFC here:
# https://github.com/vllm-project/vllm/issues/14703 # https://github.com/vllm-project/vllm/issues/14703
if Version(importlib.metadata.version('torch')) >= Version("2.6"): if is_torch_equal_or_newer("2.6"):
KEY = 'enable_auto_functionalized_v2' KEY = 'enable_auto_functionalized_v2'
if KEY not in self.inductor_compile_config: if KEY not in self.inductor_compile_config:
self.inductor_compile_config[KEY] = False self.inductor_compile_config[KEY] = False

View File

@ -53,6 +53,7 @@ import torch.types
import yaml import yaml
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from packaging import version
from packaging.version import Version from packaging.version import Version
from torch.library import Library from torch.library import Library
from typing_extensions import Never, ParamSpec, TypeIs, assert_never from typing_extensions import Never, ParamSpec, TypeIs, assert_never
@ -2580,3 +2581,20 @@ def sha256(input) -> int:
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return int.from_bytes(hashlib.sha256(input_bytes).digest(), return int.from_bytes(hashlib.sha256(input_bytes).digest(),
byteorder="big") byteorder="big")
def is_torch_equal_or_newer(target: str) -> bool:
"""Check if the installed torch version is >= the target version.
Args:
target: a version string, like "2.6.0".
Returns:
Whether the condition meets.
"""
try:
torch_version = version.parse(str(torch.__version__))
return torch_version >= version.parse(target)
except Exception:
# Fallback to PKG-INFO to load the package info, needed by the doc gen.
return Version(importlib.metadata.version('torch')) >= Version(target)