[Chore] Clean up pytorch helper functions in vllm.utils (#26908)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: isotr0py <2037008807@qq.com>
Authored by Isotr0py on 2025-10-19 00:48:22 +08:00; committed by GitHub
parent 5c2acb270a
commit 6ac5e06f7c
119 changed files with 772 additions and 714 deletions
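
In short: torch-specific helpers (STR_DTYPE_TO_TORCH_DTYPE, direct_register_custom_op, is_torch_equal_or_newer, current_stream, make_tensor_with_pad, set_default_torch_dtype, and friends) move out of the top-level vllm.utils namespace into the dedicated vllm.utils.torch_utils module, while torch-agnostic helpers such as FlexibleArgumentParser, cdiv, and get_open_port stay put. A minimal sketch of the resulting import split (illustrative usage, not a file from this diff; the "bfloat16" key is assumed from the mapping's existing contents):

# Torch-agnostic helpers remain importable from the vllm.utils top level.
from vllm.utils import FlexibleArgumentParser, cdiv

# Torch-specific helpers now come from vllm.utils.torch_utils.
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, is_torch_equal_or_newer

print(cdiv(10, 3))                           # ceil-division helper -> 4
print(is_torch_equal_or_newer("2.6"))        # version gate used throughout the diff
print(STR_DTYPE_TO_TORCH_DTYPE["bfloat16"])  # str -> torch.dtype lookup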

View File

@@ -10,7 +10,8 @@ import torch
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.triton_utils import triton
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 def with_triton_mode(fn):

View File

@@ -10,7 +10,8 @@ import vllm.model_executor.layers.activation  # noqa F401
 from vllm.model_executor.custom_op import CustomOp
 from vllm.platforms import current_platform
 from vllm.triton_utils import triton
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 batch_size_range = [1, 16, 32, 64, 128]
 seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]

View File

@@ -7,7 +7,8 @@ import torch
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.platforms import current_platform
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 @torch.inference_mode()

View File

@@ -9,9 +9,9 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import (
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import (
     STR_DTYPE_TO_TORCH_DTYPE,
-    FlexibleArgumentParser,
     create_kv_caches_with_random,
 )

View File

@@ -7,7 +7,8 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 @torch.inference_mode()

View File

@@ -9,9 +9,9 @@ from tabulate import tabulate
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import (
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import (
     STR_DTYPE_TO_TORCH_DTYPE,
-    FlexibleArgumentParser,
     create_kv_caches_with_random,
 )

View File

@@ -12,9 +12,9 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
 )
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import (
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import (
     STR_DTYPE_TO_TORCH_DTYPE,
-    FlexibleArgumentParser,
     create_kv_caches_with_random_flash,
 )

View File

@@ -11,7 +11,7 @@ from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig
 from vllm.platforms import current_platform
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 @contextlib.contextmanager

View File

@@ -20,7 +20,7 @@ from vllm.config import (
     set_current_vllm_config,
 )
 from vllm.forward_context import BatchDescriptor, set_forward_context
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 # This import automatically registers `torch.ops.silly.attention`
 from .. import silly_attention  # noqa: F401

View File

@@ -19,7 +19,7 @@ from vllm.config import (
     set_current_vllm_config,
 )
 from vllm.forward_context import BatchDescriptor, set_forward_context
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 # This import automatically registers `torch.ops.silly.attention`
 from ..silly_attention import get_global_counter, reset_global_counter

View File

@@ -27,7 +27,7 @@ from vllm.config import (
     set_current_vllm_config,
 )
 from vllm.forward_context import BatchDescriptor, set_forward_context
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 # This import automatically registers `torch.ops.silly.attention`
 from .. import silly_attention  # noqa: F401

View File

@@ -8,7 +8,7 @@ Centralizes custom operation definitions to avoid duplicate registrations.
 import torch
 from torch.library import Library
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 # Shared library for all compilation test operations
 # Using "silly" namespace to match existing test expectations

View File

@@ -15,7 +15,7 @@ from vllm.config import (
     set_current_vllm_config,
 )
 from vllm.forward_context import set_forward_context
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 def reference_fn(x: torch.Tensor):

View File

@@ -5,7 +5,7 @@ import dataclasses
 import pytest
 from vllm.config import CompilationMode
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils.torch_utils import cuda_device_count_stateless
 from ..utils import compare_all_settings

View File

@@ -8,7 +8,7 @@ from vllm.compilation.counter import compilation_counter
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.config.compilation import CompilationMode
-from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
+from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
 def test_version():

View File

@@ -15,7 +15,7 @@ from vllm.config import (
     set_current_vllm_config,
 )
 from vllm.forward_context import BatchDescriptor, set_forward_context
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 # This import automatically registers `torch.ops.silly.attention`
 from . import silly_attention  # noqa: F401

View File

@@ -12,7 +12,7 @@ from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
 from vllm.platforms import current_platform
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 from ..utils import create_new_process_for_each_test

View File

@@ -15,8 +15,8 @@ from tests.v1.attention.utils import _Backend
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
 from vllm.platforms import current_platform
-from vllm.utils import is_torch_equal_or_newer
 from vllm.utils.flashinfer import has_flashinfer
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 from ..utils import flat_product, multi_gpu_test

View File

@@ -60,8 +60,8 @@ from vllm.multimodal.utils import fetch_image
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.transformers_utils.utils import maybe_model_redirect
-from vllm.utils import set_default_torch_num_threads
 from vllm.utils.collections import is_list_of
+from vllm.utils.torch_utils import set_default_torch_num_threads
 logger = init_logger(__name__)

View File

@@ -18,7 +18,7 @@ import pytest
 from vllm.config.compilation import CompilationMode
 from vllm.config.model import RunnerOption
 from vllm.logger import init_logger
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 from ..models.registry import HF_EXAMPLE_MODELS
 from ..utils import compare_two_settings, create_new_process_for_each_test

View File

@@ -11,10 +11,10 @@ import vllm.envs as envs
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.utils import (
-    cuda_device_count_stateless,
     get_open_port,
     update_environment_variables,
 )
+from vllm.utils.torch_utils import cuda_device_count_stateless
 from ..utils import multi_gpu_test

View File

@@ -3,7 +3,10 @@
 import pytest
-from vllm.utils import create_kv_caches_with_random, create_kv_caches_with_random_flash
+from vllm.utils.torch_utils import (
+    create_kv_caches_with_random,
+    create_kv_caches_with_random_flash,
+)
 @pytest.fixture()

View File

@@ -15,7 +15,7 @@ from tests.kernels.utils import make_alibi_bias
 from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
 from vllm.platforms import current_platform
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 NUM_HEADS = [64]
 NUM_QUERIES_PER_KV = [1, 64]

View File

@@ -3,7 +3,8 @@
 import pytest
 import torch
-from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available
+from vllm.utils import is_uva_available
+from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
 CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]

View File

@@ -13,8 +13,9 @@ import torch
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.platforms import current_platform
-from vllm.utils import cuda_device_count_stateless, has_deep_ep, has_deep_gemm, has_pplx
+from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+from vllm.utils.torch_utils import cuda_device_count_stateless
 from .modular_kernel_tools.common import (
     Config,

View File

@@ -22,8 +22,8 @@ from vllm.utils import (
     STR_BACKEND_ENV_VAR,
     STR_FLASH_ATTN_VAL,
     STR_XFORMERS_ATTN_VAL,
-    make_tensor_with_pad,
 )
+from vllm.utils.torch_utils import make_tensor_with_pad
 # For now, disable "test_aot_dispatch_dynamic" since there are some
 # bugs related to this test in PyTorch 2.4.

View File

@@ -7,7 +7,7 @@ from huggingface_hub import snapshot_download
 from transformers import AutoConfig, AutoModel, CLIPImageProcessor
 from vllm.distributed import cleanup_dist_env_and_memory
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from ....conftest import ImageTestAssets

View File

@@ -9,7 +9,7 @@ from transformers import AutoConfig, AutoModel, CLIPImageProcessor
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.model_executor.models.radio import RadioModel
 from vllm.transformers_utils.configs.radio import RadioConfig
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from ....conftest import ImageTestAssets

View File

@@ -26,7 +26,6 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models.interfaces import (
     SupportsMultiModal,
     supports_multimodal,
@@ -36,6 +35,7 @@ from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingC
 from vllm.multimodal.utils import group_mm_kwargs_by_modality
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 from vllm.utils.collections import is_list_of
+from vllm.utils.torch_utils import set_default_torch_dtype
 from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
 from ...utils import dummy_hf_overrides

View File

@@ -46,10 +46,10 @@ from vllm.platforms import current_platform
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import (
     FlexibleArgumentParser,
-    cuda_device_count_stateless,
     get_open_port,
 )
 from vllm.utils.mem_constants import GB_bytes
+from vllm.utils.torch_utils import cuda_device_count_stateless
 if current_platform.is_rocm():
     from amdsmi import (

View File

@@ -24,11 +24,8 @@ from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
 from vllm.utils import (
     FlexibleArgumentParser,
     bind_kv_cache,
-    common_broadcastable_dtype,
-    current_stream,
     get_open_port,
     get_tcp_uri,
-    is_lossless_cast,
     join_host_port,
     make_zmq_path,
     make_zmq_socket,
@@ -37,6 +34,11 @@ from vllm.utils import (
     split_zmq_path,
     unique_filepath,
 )
+from vllm.utils.torch_utils import (
+    common_broadcastable_dtype,
+    current_stream,
+    is_lossless_cast,
+)
 from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
 from ..utils import create_new_process_for_each_test, flat_product
@@ -408,7 +410,7 @@ def test_bind_kv_cache_non_attention():
 def test_bind_kv_cache_pp():
-    with patch("vllm.utils.cuda_device_count_stateless", lambda: 2):
+    with patch("vllm.utils.torch_utils.cuda_device_count_stateless", lambda: 2):
         # this test runs with 1 GPU, but we simulate 2 GPUs
         cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
         with set_current_vllm_config(cfg):
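
The test_bind_kv_cache_pp change above is the one place where this refactor is behavior-relevant rather than cosmetic: unittest.mock.patch replaces an attribute on the module named in the target string, so once the helper moves, a patch aimed at the old vllm.utils path would silently stop intercepting calls. A small sketch of the rule (assuming only that the helper now lives at the new path):

from unittest.mock import patch

# patch() must name the module that actually owns the attribute at runtime;
# after this commit that is vllm.utils.torch_utils, not vllm.utils.
with patch("vllm.utils.torch_utils.cuda_device_count_stateless", lambda: 2):
    from vllm.utils.torch_utils import cuda_device_count_stateless
    assert cuda_device_count_stateless() == 2  # resolves to the patched lambda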

View File

@@ -18,7 +18,8 @@ from tests.v1.attention.utils import (
 from vllm.attention.backends.registry import _Backend
 from vllm.config import ModelConfig
 from vllm.platforms import current_platform
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
+from vllm.utils import cdiv
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, is_torch_equal_or_newer
 from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
     set_kv_cache_layout,

View File

@@ -22,7 +22,8 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.registry import _Backend
 from vllm.attention.ops.flashmla import is_flashmla_dense_supported
 from vllm.config.vllm import set_current_vllm_config
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
+from vllm.utils import cdiv
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec

View File

@@ -15,7 +15,7 @@ from vllm.inputs import PromptType
 from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
 from vllm.sampling_params import RequestOutputKind
-from vllm.utils import set_default_torch_num_threads
+from vllm.utils.torch_utils import set_default_torch_num_threads
 from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.metrics.loggers import (
     AggregatedLoggingStatLogger,

View File

@@ -12,7 +12,7 @@ from transformers import AutoTokenizer
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
-from vllm.utils import set_default_torch_num_threads
+from vllm.utils.torch_utils import set_default_torch_num_threads
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
 from vllm.v1.executor.abstract import Executor, UniProcExecutor

View File

@@ -21,7 +21,7 @@ from vllm.distributed.kv_events import BlockStored, KVEventBatch, ZmqEventPublis
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import set_default_torch_num_threads
+from vllm.utils.torch_utils import set_default_torch_num_threads
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
 from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient

View File

@@ -7,7 +7,8 @@ import torch
 from tests.v1.sample.utils import create_allowed_token_ids
 from vllm.platforms import current_platform
-from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+from vllm.utils import is_pin_memory_available
+from vllm.utils.torch_utils import make_tensor_with_pad
 from vllm.v1.sample.logits_processor import LogitsProcessors
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import Sampler

View File

@@ -9,7 +9,7 @@ import regex as re
 import torch
 from vllm import CompletionOutput
-from vllm.utils import make_tensor_with_pad
+from vllm.utils.torch_utils import make_tensor_with_pad
 from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor
 from vllm.v1.sample.metadata import SamplingMetadata

View File

@@ -12,7 +12,7 @@ from tests.v1.shutdown.utils import (
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.sampling_params import RequestOutputKind
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils.torch_utils import cuda_device_count_stateless
 from vllm.v1.engine.async_llm import AsyncLLM
 MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]

View File

@@ -14,7 +14,7 @@ from tests.v1.shutdown.utils import (
 from vllm import LLM, AsyncEngineArgs, SamplingParams
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.models.llama import LlamaForCausalLM
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils.torch_utils import cuda_device_count_stateless
 from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.engine.exceptions import EngineDeadError

View File

@@ -13,7 +13,7 @@ from vllm import LLM
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.model_executor.models.llama import LlamaForCausalLM
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils.torch_utils import cuda_device_count_stateless
 from vllm.v1.engine.async_llm import AsyncLLM
 MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]

View File

@@ -10,7 +10,8 @@ import torch
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
-from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+from vllm.utils import is_pin_memory_available
+from vllm.utils.torch_utils import make_tensor_with_pad
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import LogitsProcessors
 from vllm.v1.sample.metadata import SamplingMetadata

View File

@@ -35,7 +35,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import current_platform
-from vllm.utils import (
+from vllm.utils.torch_utils import (
     direct_register_custom_op,
     kv_cache_dtype_str_to_dtype,
 )

View File

@@ -5,7 +5,7 @@
 import torch
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
+from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
 def get_aiter_mla_metadata(

View File

@@ -24,8 +24,8 @@ from vllm.compilation.partition_rules import (
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import is_torch_equal_or_newer
 from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 from .caching import VllmSerializableFunction
 from .compiler_interface import (

View File

@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     kFp8StaticTensorSym,
 )
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 from .inductor_pass import enable_fake_mode
 from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm

View File

@@ -16,7 +16,7 @@ import torch.fx as fx
 import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.config import VllmConfig
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 class CompilerInterface:

View File

@@ -17,7 +17,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import set_graph_poo
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import weak_ref_tensors
+from vllm.utils.torch_utils import weak_ref_tensors
 logger = init_logger(__name__)

View File

@@ -21,8 +21,8 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
-from vllm.utils import supports_dynamo
 from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.torch_utils import supports_dynamo
 from .monitor import start_monitoring_torch_compile

View File

@@ -14,7 +14,7 @@ import torch
 from torch import fx
 from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 if is_torch_equal_or_newer("2.6"):
     from torch._inductor.custom_graph_pass import CustomGraphPass

View File

@@ -16,8 +16,8 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
 from vllm.config.utils import config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import is_torch_equal_or_newer
 from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 if TYPE_CHECKING:
     from vllm.config import VllmConfig

View File

@@ -41,8 +41,9 @@ from vllm.transformers_utils.config import (
 )
 from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
 from vllm.transformers_utils.utils import maybe_model_redirect
-from vllm.utils import LayerBlockType, common_broadcastable_dtype
+from vllm.utils import LayerBlockType
 from vllm.utils.import_utils import LazyLoader
+from vllm.utils.torch_utils import common_broadcastable_dtype
 if TYPE_CHECKING:
     from transformers import PretrainedConfig

View File

@@ -18,7 +18,8 @@ from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
-from vllm.utils import cuda_device_count_stateless, get_open_ports_list
+from vllm.utils import get_open_ports_list
+from vllm.utils.torch_utils import cuda_device_count_stateless
 if TYPE_CHECKING:
     from ray.runtime_env import RuntimeEnv

View File

@@ -22,7 +22,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
 )
-from vllm.utils import cuda_device_count_stateless, update_environment_variables
+from vllm.utils import update_environment_variables
+from vllm.utils.torch_utils import cuda_device_count_stateless
 logger = init_logger(__name__)

View File

@@ -17,7 +17,7 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
 from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils.torch_utils import cuda_device_count_stateless
 try:
     ops.meta_size()

View File

@@ -19,7 +19,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
 )
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
-from vllm.utils import current_stream
+from vllm.utils.torch_utils import current_stream
 logger = init_logger(__name__)
@@ -30,7 +30,7 @@ def register_nccl_symmetric_ops(pynccl_comm):
     from vllm.distributed.device_communicators.pynccl_allocator import (
         nccl_symm_mem_context,
     )
-    from vllm.utils import direct_register_custom_op
+    from vllm.utils.torch_utils import direct_register_custom_op
     global _NCCL_SYMM_OPS_REGISTERED
     if _NCCL_SYMM_OPS_REGISTERED:

View File

@@ -13,7 +13,7 @@ from vllm.config import get_current_vllm_config
 from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils.torch_utils import cuda_device_count_stateless
 logger = init_logger(__name__)

View File

@@ -14,7 +14,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
 )
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
-from vllm.utils import current_stream
+from vllm.utils.torch_utils import current_stream
 logger = init_logger(__name__)

View File

@@ -25,7 +25,8 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
 from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import (  # noqa: E501
     TensorMemoryPool,
 )
-from vllm.utils import current_stream, get_ip
+from vllm.utils import get_ip
+from vllm.utils.torch_utils import current_stream
 logger = logging.getLogger(__name__)

View File

@@ -50,11 +50,13 @@ from vllm.distributed.device_communicators.base_device_communicator import (
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
 from vllm.utils import (
-    direct_register_custom_op,
     get_distributed_init_method,
-    supports_custom_op,
 )
 from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.torch_utils import (
+    direct_register_custom_op,
+    supports_custom_op,
+)
 @dataclass

View File

@@ -29,7 +29,8 @@ from torch.distributed.rendezvous import rendezvous
 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.utils import get_tcp_uri, is_torch_equal_or_newer
+from vllm.utils import get_tcp_uri
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 logger = init_logger(__name__)

View File

@@ -5,7 +5,7 @@ import os
 import torch
 from vllm.logger import init_logger
-from vllm.utils import is_torch_equal
+from vllm.utils.torch_utils import is_torch_equal
 logger = init_logger(__name__)

View File

@@ -246,7 +246,7 @@ def maybe_convert_bool(value: str | None) -> bool | None:
 def use_aot_compile() -> bool:
-    from vllm.utils import is_torch_equal_or_newer
+    from vllm.utils.torch_utils import is_torch_equal_or_newer
     default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0"
     return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
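
Note that the import inside use_aot_compile() stays function-local; only its target module changes. Keeping it deferred is presumably to avoid a module-level import cycle between vllm.envs and the utils package (an assumption; the diff itself only shows the path change). A sketch of the resulting function with the rationale spelled out in comments:

import os

def use_aot_compile() -> bool:
    # Deferred import: resolved on each call instead of at module load,
    # so vllm.envs can be imported before vllm.utils.torch_utils is ready.
    from vllm.utils.torch_utils import is_torch_equal_or_newer

    default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0"
    return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"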

View File

@@ -12,7 +12,7 @@ import torch
 from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
 from vllm.triton_utils import tl, triton
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 @triton.jit

View File

@@ -12,7 +12,7 @@ import torch
 from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
 from vllm.triton_utils import tl, triton
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 @triton.jit

View File

@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
 )
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 def flashinfer_fused_moe_blockscale_fp8(

View File

@@ -52,8 +52,8 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Sc
 from vllm.model_executor.utils import maybe_disable_graph_partition
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
 from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
+from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

View File

@@ -52,8 +52,9 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import cdiv, direct_register_custom_op, has_deep_ep, has_pplx, round_up
+from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.worker.ubatching import dbo_current_ubatch_id
 if current_platform.is_cuda_alike():

View File

@@ -11,7 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
 )
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 class QuantMethod(IntEnum):

View File

@@ -23,8 +23,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
     mxfp8_e4m3_quantize,
 )
 from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv, is_torch_equal_or_newer
+from vllm.utils import cdiv
 from vllm.utils.flashinfer import flashinfer_fp4_quantize
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 @triton.jit

View File

@@ -13,7 +13,7 @@ from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 def is_rocm_aiter_rmsnorm_enabled() -> bool:

View File

@@ -34,7 +34,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateShapeCalculator,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
 if TYPE_CHECKING:

View File

@@ -37,7 +37,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_state_update,
 )
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata

View File

@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     sharded_weight_loader,
 )
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
 # Added by the IBM Team, 2024

View File

@@ -6,7 +6,10 @@ import torch
 from vllm.config.cache import MambaDType
 from vllm.config.model import ModelDType
 from vllm.distributed import divide
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype
+from vllm.utils.torch_utils import (
+    STR_DTYPE_TO_TORCH_DTYPE,
+    get_kv_cache_torch_dtype,
+)
 class MambaStateDtypeCalculator:

View File

@@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn,
     causal_conv1d_update,
 )
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata

View File

@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization import (
     QuantizationMethods,
 )
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 class BitsAndBytesConfig(QuantizationConfig):

View File

@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 class FPQuantConfig(QuantizationConfig):

View File

@@ -28,7 +28,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 )
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 logger = init_logger(__name__)

View File

@@ -7,7 +7,7 @@ import torch
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 from .cutlass import CutlassScaledMMLinearKernel
 from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig

View File

@@ -49,10 +49,10 @@ from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import (
     has_triton_kernels,
-    is_torch_equal_or_newer,
     round_up,
 )
 from vllm.utils.flashinfer import has_flashinfer
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 logger = init_logger(__name__)

View File

@@ -45,7 +45,7 @@ try:
     from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
     from aiter.ops.triton.quant import dynamic_mxfp4_quant
-    from vllm.utils import direct_register_custom_op
+    from vllm.utils.torch_utils import direct_register_custom_op
     if is_rocm_aiter_fp4_asm_gemm_enabled():
         from aiter import gemm_a4w4, per_1x32_f4_quant_hip

View File

@@ -28,13 +28,13 @@ from vllm.model_executor.parameter import (
 )
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import direct_register_custom_op
 from vllm.utils.deep_gemm import (
     fp8_gemm_nt,
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
     should_use_deepgemm_for_fp8_linear,
 )
+from vllm.utils.torch_utils import direct_register_custom_op
 logger = init_logger(__name__)

View File

@@ -7,7 +7,7 @@ import torch
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
+from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
 logger = init_logger(__name__)

View File

@@ -3,7 +3,7 @@
 import torch
 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 def _quant_dequant_mxfp6(

View File

@@ -12,8 +12,8 @@ from vllm.config import CompilationMode, get_current_vllm_config
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
 from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
+from vllm.utils.torch_utils import direct_register_custom_op
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale

View File

@@ -10,7 +10,7 @@ import torch
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 if current_platform.is_cuda():
     from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

View File

@@ -5,7 +5,7 @@ import torch
 import vllm.envs as envs
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 def is_rocm_triton_rotary_embedding_enabled() -> bool:

View File

@@ -9,7 +9,7 @@ import torch
 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.platforms import CpuArchEnum, current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
 def shuffle_weight(w: torch.Tensor) -> torch.Tensor:

View File

@@ -11,8 +11,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.model_loader.utils import (
     initialize_model,
     process_weights_after_loading,
-    set_default_torch_dtype,
 )
+from vllm.utils.torch_utils import set_default_torch_dtype
 logger = init_logger(__name__)

View File

@@ -32,7 +32,7 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
-from vllm.model_executor.model_loader.utils import ParamMapping, set_default_torch_dtype
+from vllm.model_executor.model_loader.utils import ParamMapping
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
@@ -48,6 +48,7 @@ from vllm.model_executor.utils import (
     set_weight_attrs,
 )
 from vllm.platforms import current_platform
+from vllm.utils.torch_utils import set_default_torch_dtype
 logger = init_logger(__name__)

View File

@@ -15,13 +15,13 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.utils import (
     initialize_model,
     process_weights_after_loading,
-    set_default_torch_dtype,
 )
 from vllm.model_executor.model_loader.weight_utils import (
     get_gguf_extra_tensor_names,
     get_gguf_weight_type_map,
     gguf_quant_weights_iterator,
 )
+from vllm.utils.torch_utils import set_default_torch_dtype
 class GGUFModelLoader(BaseModelLoader):

View File

@@ -22,8 +22,8 @@ from vllm.model_executor.model_loader.tensorizer import (
 from vllm.model_executor.model_loader.utils import (
     get_model_architecture,
     initialize_model,
-    set_default_torch_dtype,
 )
+from vllm.utils.torch_utils import set_default_torch_dtype
 logger = init_logger(__name__)

View File

@@ -14,8 +14,8 @@ from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
 from vllm.model_executor.model_loader.utils import (
     initialize_model,
     process_weights_after_loading,
-    set_default_torch_dtype,
 )
+from vllm.utils.torch_utils import set_default_torch_dtype
 logger = init_logger(__name__)

View File

@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Utilities for selecting and loading models."""
-import contextlib
 import inspect
 import warnings
 from contextlib import contextmanager
@@ -32,15 +31,6 @@ from vllm.utils import is_pin_memory_available
 logger = init_logger(__name__)
-@contextlib.contextmanager
-def set_default_torch_dtype(dtype: torch.dtype):
-    """Sets the default torch dtype to the given dtype."""
-    old_dtype = torch.get_default_dtype()
-    torch.set_default_dtype(dtype)
-    yield
-    torch.set_default_dtype(old_dtype)
 def initialize_model(
     vllm_config: VllmConfig,
     *,
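
For reference, the context manager removed above is relocated, not dropped: the loader hunks elsewhere in this diff now import set_default_torch_dtype from vllm.utils.torch_utils. Its behavior, reproduced from the deleted lines with a usage sketch appended:

import contextlib

import torch

@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)

# Typical model-loader usage: parameters created inside the block default
# to the target dtype, and the previous default is restored on exit.
with set_default_torch_dtype(torch.bfloat16):
    layer = torch.nn.Linear(8, 8)
assert layer.weight.dtype == torch.bfloat16
assert torch.get_default_dtype() == torch.float32  # assuming float32 was the prior default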

View File

@@ -6,7 +6,8 @@ from typing import TYPE_CHECKING
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, round_up
+from vllm.utils import cdiv, round_up
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
 if TYPE_CHECKING:

View File

@@ -79,8 +79,8 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import direct_register_custom_op
 from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
+from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backends.mla.indexer import (
     DeepseekV32IndexerBackend,
     DeepseekV32IndexerMetadata,

View File

@@ -18,7 +18,6 @@ from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models.transformers.utils import replace_linear_class
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
@@ -51,6 +50,7 @@ from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processo
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 from vllm.utils.collections import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
+from vllm.utils.torch_utils import set_default_torch_dtype
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (

View File

@@ -51,8 +51,8 @@ from vllm.multimodal.processing import (
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import set_default_torch_num_threads
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
+from vllm.utils.torch_utils import set_default_torch_num_threads
 from .interfaces import (
     MultiModalEmbeddings,
View File

@@ -49,7 +49,6 @@ from vllm.model_executor.layers.resampler import (
     Resampler2,
     get_2d_sincos_pos_embed,
 )
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -88,6 +87,7 @@ from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils.collections import flatten_2d_lists
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
+from vllm.utils.torch_utils import set_default_torch_dtype
 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import (
Some files were not shown because too many files have changed in this diff.