[Bugfix] neuron: enable tensor parallelism (#7562)
Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com>
parent 05826c887b
commit 760e9f71a8
vllm/engine/arg_utils.py
@@ -317,9 +317,10 @@ class EngineArgs:
         parser.add_argument('--block-size',
                             type=int,
                             default=EngineArgs.block_size,
-                            choices=[8, 16, 32, 128, 256, 512, 1024, 2048],
+                            choices=[8, 16, 32],
                             help='Token block size for contiguous chunks of '
-                            'tokens.')
+                            'tokens. This is ignored on neuron devices and '
+                            'set to max-model-len')
 
         parser.add_argument('--enable-prefix-caching',
                             action='store_true',
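With the help text above and the CacheConfig change below, Neuron users set the context length rather than a block size. A hedged usage sketch in Python (the model id, TP degree, and context length are illustrative values, not part of this commit):

from vllm import LLM, SamplingParams

# Example only: on a Neuron (Inferentia/Trainium) host, --block-size is
# ignored and the KV-cache block size is forced to max_model_len.
llm = LLM(
    model="openlm-research/open_llama_3b",  # illustrative model id
    device="neuron",
    tensor_parallel_size=2,  # sharding is handled by transformers-neuronx
    max_model_len=2048,
)
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)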
@@ -793,7 +794,8 @@ class EngineArgs:
             limit_mm_per_prompt=self.limit_mm_per_prompt,
         )
         cache_config = CacheConfig(
-            block_size=self.block_size,
+            block_size=self.block_size if self.device != "neuron" else
+            self.max_model_len,  # neuron needs block_size = max_model_len
             gpu_memory_utilization=self.gpu_memory_utilization,
             swap_space=self.swap_space,
             cache_dtype=self.kv_cache_dtype,
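This conditional is the config-level half of the fix: on Neuron the whole context lives in a single KV-cache block. A minimal sketch of the same selection logic, with stand-in arguments rather than the real EngineArgs fields:

def resolve_block_size(block_size: int, device: str, max_model_len: int) -> int:
    # Mirrors the expression above: neuron needs block_size = max_model_len.
    return block_size if device != "neuron" else max_model_len

assert resolve_block_size(16, "cuda", 2048) == 16
assert resolve_block_size(16, "neuron", 2048) == 2048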
vllm/executor/neuron_executor.py
@@ -4,7 +4,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
-from vllm.utils import make_async
+from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
+                        make_async)
 
 logger = init_logger(__name__)
 
@@ -24,14 +25,17 @@ class NeuronExecutor(ExecutorBase):
 
     def _init_worker(self):
         from vllm.worker.neuron_worker import NeuronWorker
-
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
         self.driver_worker = NeuronWorker(
-            self.model_config,
-            self.parallel_config,
-            self.scheduler_config,
-            self.device_config,
-            self.cache_config,
-        )
+            model_config=self.model_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config,
+            device_config=self.device_config,
+            cache_config=self.cache_config,
+            local_rank=0,
+            rank=0,
+            distributed_init_method=distributed_init_method)
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
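The helpers imported earlier assemble the torch.distributed rendezvous address that the worker now receives. A rough sketch of equivalent logic (an approximation, not the actual vllm.utils implementation):

import socket

def sketch_get_ip() -> str:
    # Discover the host's outbound IP via a UDP "connect" (sends no packets).
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(("8.8.8.8", 80))
        return s.getsockname()[0]

def sketch_get_open_port() -> int:
    # Let the OS pick a free TCP port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]

# get_distributed_init_method(ip, port) yields a URL like "tcp://<ip>:<port>".
print(f"tcp://{sketch_get_ip()}:{sketch_get_open_port()}")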
vllm/worker/neuron_worker.py
@@ -6,6 +6,8 @@ import torch.distributed
 
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
+from vllm.distributed import (ensure_model_parallel_initialized,
+                              init_distributed_environment)
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
 from vllm.worker.neuron_model_runner import NeuronModelRunner
@@ -24,12 +26,18 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
         cache_config: CacheConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config
+        self.local_rank = local_rank
+        self.rank = rank
+        self.distributed_init_method = distributed_init_method
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
             from vllm.utils import init_cached_hf_modules
@@ -40,6 +48,8 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         self.is_driver_worker = True
 
     def init_device(self) -> None:
+        self.init_distributed_environment()
+
         # Set random seed.
         set_random_seed(self.model_config.seed)
 
@@ -98,3 +108,20 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         This is required for speculative decoding; it is not yet implemented.
         """
         raise NotImplementedError
+
+    def init_distributed_environment(self):
+        """Neuron uses transformers-neuronx for tensor parallelism.
+
+        vLLM still needs the environment inited when TP/PP > 1
+        """
+        init_distributed_environment(
+            world_size=1,
+            rank=self.rank,
+            local_rank=self.local_rank,
+            distributed_init_method=self.distributed_init_method,
+            backend="gloo",
+        )
+        ensure_model_parallel_initialized(
+            1,
+            1,
+        )
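As the docstring notes, transformers-neuronx performs the actual tensor-parallel sharding across NeuronCores; vLLM only needs a single-rank process group so its collective-communication plumbing is satisfied when TP/PP > 1. In plain torch.distributed, the bring-up above is roughly equivalent to the following (an approximation with an example rendezvous address, not vLLM's actual wrapper):

import torch.distributed as dist

dist.init_process_group(
    backend="gloo",                       # CPU-side collectives; no NCCL needed
    init_method="tcp://127.0.0.1:29500",  # example address; vLLM builds its own
    world_size=1,                         # single rank: neuronx shards internally
    rank=0,
)
assert dist.get_world_size() == 1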