mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 01:37:53 +08:00
Fix the torch version parsing logic (#15857)
This commit is contained in:
parent
8661c0241d
commit
7678fcd5b6
@ -2,7 +2,6 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import hashlib
|
import hashlib
|
||||||
import importlib.metadata
|
|
||||||
import os
|
import os
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
@ -11,9 +10,9 @@ from unittest.mock import patch
|
|||||||
import torch
|
import torch
|
||||||
import torch._inductor.compile_fx
|
import torch._inductor.compile_fx
|
||||||
import torch.fx as fx
|
import torch.fx as fx
|
||||||
from packaging.version import Version
|
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
|
|
||||||
class CompilerInterface:
|
class CompilerInterface:
|
||||||
@ -379,7 +378,7 @@ class InductorAdaptor(CompilerInterface):
|
|||||||
manually setting up internal contexts. But we also rely on non-public
|
manually setting up internal contexts. But we also rely on non-public
|
||||||
APIs which might not provide these guarantees.
|
APIs which might not provide these guarantees.
|
||||||
"""
|
"""
|
||||||
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
|
if is_torch_equal_or_newer("2.6"):
|
||||||
import torch._dynamo.utils
|
import torch._dynamo.utils
|
||||||
return torch._dynamo.utils.get_metrics_context()
|
return torch._dynamo.utils.get_metrics_context()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1,17 +1,17 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import importlib.metadata
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import types
|
import types
|
||||||
from typing import Any, Callable, Dict, Optional, Union
|
from typing import Any, Callable, Dict, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging.version import Version
|
|
||||||
from torch import fx
|
from torch import fx
|
||||||
|
|
||||||
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
|
from vllm.utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
|
if is_torch_equal_or_newer("2.6"):
|
||||||
from torch._inductor.custom_graph_pass import CustomGraphPass
|
from torch._inductor.custom_graph_pass import CustomGraphPass
|
||||||
else:
|
else:
|
||||||
# CustomGraphPass is not present in 2.5 or lower, import our version
|
# CustomGraphPass is not present in 2.5 or lower, import our version
|
||||||
|
|||||||
@ -4,7 +4,6 @@ import ast
|
|||||||
import copy
|
import copy
|
||||||
import enum
|
import enum
|
||||||
import hashlib
|
import hashlib
|
||||||
import importlib.metadata
|
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
@ -18,7 +17,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
|
|||||||
Optional, Protocol, Union)
|
Optional, Protocol, Union)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging.version import Version
|
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
from torch.distributed import ProcessGroup, ReduceOp
|
from torch.distributed import ProcessGroup, ReduceOp
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
@ -40,8 +38,8 @@ from vllm.transformers_utils.config import (
|
|||||||
from vllm.transformers_utils.s3_utils import S3Model
|
from vllm.transformers_utils.s3_utils import S3Model
|
||||||
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
|
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
|
||||||
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
|
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
|
||||||
get_cpu_memory, get_open_port, random_uuid,
|
get_cpu_memory, get_open_port, is_torch_equal_or_newer,
|
||||||
resolve_obj_by_qualname)
|
random_uuid, resolve_obj_by_qualname)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.util.placement_group import PlacementGroup
|
from ray.util.placement_group import PlacementGroup
|
||||||
@ -3285,7 +3283,7 @@ class CompilationConfig(BaseModel):
|
|||||||
# and it is not yet a priority. RFC here:
|
# and it is not yet a priority. RFC here:
|
||||||
# https://github.com/vllm-project/vllm/issues/14703
|
# https://github.com/vllm-project/vllm/issues/14703
|
||||||
|
|
||||||
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
|
if is_torch_equal_or_newer("2.6"):
|
||||||
KEY = 'enable_auto_functionalized_v2'
|
KEY = 'enable_auto_functionalized_v2'
|
||||||
if KEY not in self.inductor_compile_config:
|
if KEY not in self.inductor_compile_config:
|
||||||
self.inductor_compile_config[KEY] = False
|
self.inductor_compile_config[KEY] = False
|
||||||
|
|||||||
@ -53,6 +53,7 @@ import torch.types
|
|||||||
import yaml
|
import yaml
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
|
from packaging import version
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
from torch.library import Library
|
from torch.library import Library
|
||||||
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
|
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
|
||||||
@ -2580,3 +2581,20 @@ def sha256(input) -> int:
|
|||||||
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
return int.from_bytes(hashlib.sha256(input_bytes).digest(),
|
return int.from_bytes(hashlib.sha256(input_bytes).digest(),
|
||||||
byteorder="big")
|
byteorder="big")
|
||||||
|
|
||||||
|
|
||||||
|
def is_torch_equal_or_newer(target: str) -> bool:
|
||||||
|
"""Check if the installed torch version is >= the target version.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target: a version string, like "2.6.0".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether the condition meets.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
torch_version = version.parse(str(torch.__version__))
|
||||||
|
return torch_version >= version.parse(target)
|
||||||
|
except Exception:
|
||||||
|
# Fallback to PKG-INFO to load the package info, needed by the doc gen.
|
||||||
|
return Version(importlib.metadata.version('torch')) >= Version(target)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user