[Misc] Fix ImportError caused by triton (#9493)

This commit is contained in:
Mengqing Cao 2024-11-08 13:08:51 +08:00 committed by GitHub
parent ad39bd640c
commit 7371749d54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -13,12 +13,15 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.triton_utils import maybe_set_triton_cache_manager
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
cuda_is_initialized, get_distributed_init_method,
get_open_port, get_vllm_instance_id, make_async,
update_environment_variables)
if HAS_TRITON:
from vllm.triton_utils import maybe_set_triton_cache_manager
logger = init_logger(__name__)
@ -59,7 +62,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
torch.set_num_threads(default_omp_num_threads)
# workaround for https://github.com/vllm-project/vllm/issues/6103
if world_size > 1:
if HAS_TRITON and world_size > 1:
maybe_set_triton_cache_manager()
# Multiprocessing-based executor does not support multi-node setting.