From ba0bfd40e21cacfd5da6a1e43028a37258a29cb4 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 2 Oct 2023 15:36:09 -0700 Subject: [PATCH] TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181) --- .github/workflows/pylint.yml | 2 +- .github/workflows/yapf.yml | 2 +- .pylintrc | 2 +- format.sh | 5 +- tests/async_engine/api_server_async_engine.py | 1 + tests/async_engine/test_api_server.py | 3 + tests/async_engine/test_async_llm_engine.py | 4 +- tests/async_engine/test_request_tracker.py | 22 +- tests/conftest.py | 1 + tests/distributed/test_comm_ops.py | 82 ++++ tests/engine/test_detokenize.py | 1 + tests/kernels/test_activation.py | 12 +- tests/kernels/test_cache.py | 6 +- tests/kernels/test_pos_encoding.py | 2 +- tests/samplers/test_sampler.py | 3 +- .../layers/quantized_linear/__init__.py | 4 +- .../layers/quantized_linear/awq.py | 4 +- vllm/model_executor/layers/sampler.py | 6 +- vllm/model_executor/models/aquila.py | 42 +- vllm/model_executor/models/baichuan.py | 42 +- vllm/model_executor/models/bloom.py | 29 +- vllm/model_executor/models/falcon.py | 31 +- vllm/model_executor/models/gpt2.py | 49 +- vllm/model_executor/models/gpt_bigcode.py | 61 +-- vllm/model_executor/models/gpt_j.py | 61 +-- vllm/model_executor/models/gpt_neox.py | 52 ++- vllm/model_executor/models/internlm.py | 44 +- vllm/model_executor/models/llama.py | 12 +- vllm/model_executor/models/mistral.py | 12 +- vllm/model_executor/models/mpt.py | 36 +- vllm/model_executor/models/opt.py | 51 ++- vllm/model_executor/models/qwen.py | 14 +- .../model_executor/parallel_utils/__init__.py | 7 - .../parallel_utils/communication_op.py | 47 ++ vllm/model_executor/parallel_utils/layers.py | 303 +++++++++++++ .../parallel_utils/parallel_state.py | 426 +++--------------- .../tensor_parallel/__init__.py | 50 -- .../parallel_utils/tensor_parallel/layers.py | 366 --------------- .../tensor_parallel/mappings.py | 281 ------------ .../parallel_utils/tensor_parallel/random.py | 164 ------- .../{tensor_parallel => }/utils.py | 18 +- vllm/model_executor/utils.py | 6 - 42 files changed, 819 insertions(+), 1547 deletions(-) create mode 100644 tests/distributed/test_comm_ops.py create mode 100644 vllm/model_executor/parallel_utils/communication_op.py create mode 100644 vllm/model_executor/parallel_utils/layers.py delete mode 100644 vllm/model_executor/parallel_utils/tensor_parallel/__init__.py delete mode 100644 vllm/model_executor/parallel_utils/tensor_parallel/layers.py delete mode 100644 vllm/model_executor/parallel_utils/tensor_parallel/mappings.py delete mode 100644 vllm/model_executor/parallel_utils/tensor_parallel/random.py rename vllm/model_executor/parallel_utils/{tensor_parallel => }/utils.py (86%) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 5e096f3c6e757..1c810adbe3ef4 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -28,4 +28,4 @@ jobs: pip install pylint==2.8.2 - name: Analysing the code with pylint run: | - pylint vllm + pylint vllm tests diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml index 590e27597ecdc..b77a1a470107b 100644 --- a/.github/workflows/yapf.yml +++ b/.github/workflows/yapf.yml @@ -28,4 +28,4 @@ jobs: pip install toml==0.10.2 - name: Running yapf run: | - yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**' + yapf --diff --recursive vllm tests diff --git a/.pylintrc b/.pylintrc index 911d1bc7e2a5e..f85ab742bec30 100644 --- a/.pylintrc +++ b/.pylintrc @@ -8,7 +8,7 @@ [MASTER] # Files or directories to be skipped. They should be base names, not paths. -ignore=docs,parallel_utils +ignore=docs # Files or directories matching the regex patterns are skipped. The regex # matches against base names, not paths. diff --git a/format.sh b/format.sh index 4fd8f2bd49cb3..9d6d5b6ea4e72 100755 --- a/format.sh +++ b/format.sh @@ -44,7 +44,6 @@ YAPF_FLAGS=( YAPF_EXCLUDES=( '--exclude' 'build/**' - '--exclude' 'vllm/model_executor/parallel_utils/**' ) # Format specified files @@ -72,7 +71,7 @@ format_changed() { # Format all files format_all() { - yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm + yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm tests } ## This flag formats individual files. --files *must* be the first command line @@ -96,7 +95,7 @@ echo 'vLLM yapf: Done' # Run Pylint echo 'vLLM Pylint:' -pylint vllm +pylint vllm tests if ! git diff --quiet &>/dev/null; then echo 'Reformatted files. Please review and stage the changes.' diff --git a/tests/async_engine/api_server_async_engine.py b/tests/async_engine/api_server_async_engine.py index 1be76fdc8d868..515d7a801e9be 100644 --- a/tests/async_engine/api_server_async_engine.py +++ b/tests/async_engine/api_server_async_engine.py @@ -14,6 +14,7 @@ app = vllm.entrypoints.api_server.app class AsyncLLMEngineWithStats(AsyncLLMEngine): + # pylint: disable=redefined-outer-name def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._num_aborts = 0 diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index dee62b9a6a96b..1ca4826b27f3b 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -24,6 +24,7 @@ def _query_server(prompt: str) -> dict: def api_server(): script_path = Path(__file__).parent.joinpath( "api_server_async_engine.py").absolute() + # pylint: disable=consider-using-with uvicorn_process = subprocess.Popen([ sys.executable, "-u", str(script_path), "--model", "facebook/opt-125m" @@ -32,6 +33,7 @@ def api_server(): uvicorn_process.terminate() +# pylint: disable=redefined-outer-name, unused-argument def test_api_server(api_server): """ Run the API server and test it. @@ -47,6 +49,7 @@ def test_api_server(api_server): prompts = ["Hello world"] * 1 result = None while not result: + # pylint: disable=bare-except try: for result in pool.map(_query_server, prompts): break diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 44ad201e914be..174975802dc0d 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -32,12 +32,12 @@ class MockEngine: self.request_id = None def add_request(self, **kwargs): + del kwargs # Unused self.add_request_calls += 1 - return def abort_request(self, request_id): + del request_id # Unused self.abort_request_calls += 1 - return class MockAsyncLLMEngine(AsyncLLMEngine): diff --git a/tests/async_engine/test_request_tracker.py b/tests/async_engine/test_request_tracker.py index 7787381f97d10..83b306e75950d 100644 --- a/tests/async_engine/test_request_tracker.py +++ b/tests/async_engine/test_request_tracker.py @@ -7,22 +7,22 @@ from vllm.outputs import RequestOutput class DummyEvent: def __init__(self): - self._flag = False + self.flag = False def set(self): - self._flag = True + self.flag = True def clear(self): - self._flag = False + self.flag = False def test_request_tracker(): tracker = RequestTracker() tracker.new_requests_event = DummyEvent() stream_1 = tracker.add_request("1") - assert tracker.new_requests_event._flag + assert tracker.new_requests_event.flag new, finished = tracker.get_new_and_finished_requests() - assert not tracker.new_requests_event._flag + assert not tracker.new_requests_event.flag assert len(new) == 1 assert new[0]["request_id"] == "1" assert not finished @@ -30,9 +30,9 @@ def test_request_tracker(): stream_2 = tracker.add_request("2") stream_3 = tracker.add_request("3") - assert tracker.new_requests_event._flag + assert tracker.new_requests_event.flag new, finished = tracker.get_new_and_finished_requests() - assert not tracker.new_requests_event._flag + assert not tracker.new_requests_event.flag assert len(new) == 2 assert new[0]["request_id"] == "2" assert new[1]["request_id"] == "3" @@ -43,7 +43,7 @@ def test_request_tracker(): # request_ids must be unique with pytest.raises(KeyError): tracker.add_request("1") - assert not tracker.new_requests_event._flag + assert not tracker.new_requests_event.flag tracker.abort_request("1") new, finished = tracker.get_new_and_finished_requests() @@ -54,7 +54,7 @@ def test_request_tracker(): stream_4 = tracker.add_request("4") tracker.abort_request("4") - assert tracker.new_requests_event._flag + assert tracker.new_requests_event.flag new, finished = tracker.get_new_and_finished_requests() assert len(finished) == 1 assert "4" in finished @@ -62,11 +62,11 @@ def test_request_tracker(): assert stream_4.finished stream_5 = tracker.add_request("5") - assert tracker.new_requests_event._flag + assert tracker.new_requests_event.flag tracker.process_request_output( RequestOutput("2", "output", [], [], finished=True)) new, finished = tracker.get_new_and_finished_requests() - assert not tracker.new_requests_event._flag + assert not tracker.new_requests_event.flag assert len(finished) == 1 assert "2" in finished assert len(new) == 1 diff --git a/tests/conftest.py b/tests/conftest.py index 686303e6bebbe..24b99a27929da 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ from vllm import LLM, SamplingParams from vllm.transformers_utils.tokenizer import get_tokenizer _TEST_PROMPTS = [ + # pylint: disable=line-too-long "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.", "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py new file mode 100644 index 0000000000000..9de0ce0de416e --- /dev/null +++ b/tests/distributed/test_comm_ops.py @@ -0,0 +1,82 @@ +"""Test the communication operators. + +Run `pytest tests/distributed/test_comm_ops.py --forked`. +""" +from multiprocessing import Process + +import pytest +import torch + +from vllm.config import ParallelConfig +from vllm.engine.ray_utils import get_open_port +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce, + tensor_model_parallel_all_gather, +) +from vllm.worker.worker import _init_distributed_environment + + +def init_test_distributed_environment(pipeline_parallel_size: int, + tensor_parallel_size: int, rank: int, + distributed_init_port: str): + parallel_config = ParallelConfig(pipeline_parallel_size, + tensor_parallel_size, + worker_use_ray=True) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + torch.cuda.set_device(rank) + _init_distributed_environment(parallel_config, rank, + distributed_init_method) + + +def all_reduce_test_worker(tensor_parallel_size: int, rank: int, + distributed_init_port: str): + init_test_distributed_environment(1, tensor_parallel_size, rank, + distributed_init_port) + num_elements = 8 + all_tensors = [ + torch.arange(num_elements, dtype=torch.float32, device="cuda") * + (r + 1) for r in range(tensor_parallel_size) + ] + expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) + t = all_tensors[rank] + t = tensor_model_parallel_all_reduce(t) + assert torch.allclose(t, expected) + + +def all_gather_test_worker(tensor_parallel_size: int, rank: int, + distributed_init_port: str): + init_test_distributed_environment(1, tensor_parallel_size, rank, + distributed_init_port) + num_dimensions = 3 + tensor_size = list(range(2, num_dimensions + 2)) + total_size = 1 + for s in tensor_size: + total_size *= s + for all_gather_dimension in range(num_dimensions): + all_tensors = [ + torch.arange(total_size, dtype=torch.float32, + device="cuda").reshape(tensor_size) * (r + 1) + for r in range(tensor_parallel_size) + ] + expected = torch.cat(all_tensors, dim=all_gather_dimension) + t = all_tensors[rank] + t = tensor_model_parallel_all_gather(t, all_gather_dimension) + assert torch.allclose(t, expected) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("tensor_parallel_size", [2]) +@pytest.mark.parametrize("test_target", + [all_reduce_test_worker, all_gather_test_worker]) +def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): + distributed_init_port = get_open_port() + processes = [] + for rank in range(tensor_parallel_size): + p = Process(target=test_target, + args=(tensor_parallel_size, rank, distributed_init_port)) + p.start() + processes.append(p) + for p in processes: + p.join() + assert all(p.exitcode == 0 for p in processes) diff --git a/tests/engine/test_detokenize.py b/tests/engine/test_detokenize.py index fc5936c7434e8..0f51af166c4b1 100644 --- a/tests/engine/test_detokenize.py +++ b/tests/engine/test_detokenize.py @@ -5,6 +5,7 @@ from transformers import AutoTokenizer from vllm.transformers_utils.tokenizer import detokenize_incrementally TRUTH = [ + # pylint: disable=line-too-long "Hello here, this is a simple test", "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving", "我很感谢你的热情" diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 8aa35d2b2340f..0b3ad0aa255a1 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -29,8 +29,8 @@ def test_silu_and_mul( ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda') - out = torch.empty(num_tokens, d, dtype=dtype, device='cuda') + x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda") + out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") activation_ops.silu_and_mul(out, x) ref_out = ref_silu_and_mul(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) @@ -49,8 +49,8 @@ def test_gelu_new( ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.randn(num_tokens, d, dtype=dtype, device='cuda') - out = torch.empty(num_tokens, d, dtype=dtype, device='cuda') + x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") + out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") activation_ops.gelu_new(out, x) ref_out = get_activation("gelu_new")(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) @@ -68,8 +68,8 @@ def test_gelu_fast( ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.randn(num_tokens, d, dtype=dtype, device='cuda') - out = torch.empty(num_tokens, d, dtype=dtype, device='cuda') + x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") + out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") activation_ops.gelu_fast(out, x) ref_out = get_activation("gelu_fast")(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index cca037df235dc..b72dfbd6688e3 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -106,14 +106,14 @@ def test_reshape_and_cache( # Create a random slot mapping. num_slots = block_size * num_blocks slot_mapping = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda') + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cuda") qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, - device='cuda') + device="cuda") _, key, value = qkv.unbind(dim=1) # Create the KV caches. @@ -132,7 +132,7 @@ def test_reshape_and_cache( # Run the reference implementation. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) - block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') + block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor") block_indicies = block_indicies.cpu().tolist() block_offsets = slot_mapping % block_size block_offsets = block_offsets.cpu().tolist() diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 0d255900d4c11..d660417440844 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -140,7 +140,7 @@ def test_rotary_embedding( cos = freqs.cos() sin = freqs.sin() cos_sin_cache = torch.cat((cos, sin), dim=-1) - cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda') + cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda") # Run the kernel. The kernel is in-place, so we need to clone the inputs. out_query = query.clone() diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index a5f55d50fbb76..74c819efea23b 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -1,3 +1,4 @@ +# pylint: disable=protected-access import pytest import random from typing import Tuple @@ -108,7 +109,7 @@ def test_sampler_all_random(seed: int): def test_sampler_all_beam(seed: int): set_random_seed(seed) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size) + input_tensor, _, sampler, worker = _prepare_test(batch_size) seq_group_metadata_list = [] for i in range(batch_size): diff --git a/vllm/model_executor/layers/quantized_linear/__init__.py b/vllm/model_executor/layers/quantized_linear/__init__.py index bcb9a54e7a2c6..eecfe8149ebf3 100644 --- a/vllm/model_executor/layers/quantized_linear/__init__.py +++ b/vllm/model_executor/layers/quantized_linear/__init__.py @@ -1,7 +1,7 @@ from vllm.model_executor.layers.quantized_linear.awq import ( AWQColumnParallelLinear, AWQRowParallelLinear) -from vllm.model_executor.parallel_utils.tensor_parallel import ( - ColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, + RowParallelLinear) _QUANTIZED_LINEAR_REGISTRY = { "awq": (AWQColumnParallelLinear, AWQRowParallelLinear), diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantized_linear/awq.py index 88c5790acfd4a..2c2d0f8ca1f9d 100644 --- a/vllm/model_executor/layers/quantized_linear/awq.py +++ b/vllm/model_executor/layers/quantized_linear/awq.py @@ -4,8 +4,8 @@ import torch from torch.nn.parameter import Parameter from vllm import quantization_ops -from vllm.model_executor.parallel_utils.tensor_parallel.layers import ( - ColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, + RowParallelLinear) class AWQColumnParallelLinear(ColumnParallelLinear): diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 9d25c60586b1f..76442eae680d2 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -5,8 +5,8 @@ import torch import torch.nn as nn from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.parallel_utils.tensor_parallel import ( - gather_from_tensor_model_parallel_region) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_gather) from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs @@ -92,7 +92,7 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, logits = torch.matmul(hidden_states, embedding.t()) if embedding_bias is not None: logits += embedding_bias - logits = gather_from_tensor_model_parallel_region(logits) + logits = tensor_model_parallel_all_gather(logits) # Remove paddings in vocab (if any). logits = logits[:, :vocab_size] return logits diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py index caf9b61ffa0e6..33280d9aefe64 100644 --- a/vllm/model_executor/models/aquila.py +++ b/vllm/model_executor/models/aquila.py @@ -39,8 +39,9 @@ from vllm.model_executor.weight_utils import ( load_tensor_parallel_weights) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.aquila import AquilaConfig @@ -56,16 +57,18 @@ class AquilaMLP(nn.Module): hidden_act: str, ): super().__init__() - self.gate_up_proj = ColumnParallelLinear(hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - perform_initialization=False) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - perform_initialization=False) + self.gate_up_proj = ColumnParallelLinear( + hidden_size, + 2 * intermediate_size, + bias=False, + gather_output=False, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -130,14 +133,12 @@ class AquilaAttention(nn.Module): self.head_dim, bias=False, gather_output=False, - perform_initialization=False, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, input_is_parallel=True, - perform_initialization=False, ) self.attn = PagedAttentionWithRoPE( self.num_heads, @@ -230,7 +231,7 @@ class AquilaModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, - perform_initialization=False) + ) self.layers = nn.ModuleList([ AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers) ]) @@ -270,11 +271,12 @@ class AquilaForCausalLM(nn.Module): self.config = config self.model = AquilaModel(config) vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.lm_head = ColumnParallelLinear(config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - perform_initialization=False) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + vocab_size, + bias=False, + gather_output=False, + ) self.sampler = Sampler(config.vocab_size) def forward( diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 277b2cc49b442..7d0454271a799 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -39,8 +39,9 @@ from vllm.model_executor.weight_utils import ( load_padded_tensor_parallel_vocab, load_tensor_parallel_weights) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.baichuan import BaiChuanConfig @@ -81,16 +82,18 @@ class BaiChuanMLP(nn.Module): hidden_act: str, ): super().__init__() - self.gate_up_proj = ColumnParallelLinear(hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - perform_initialization=False) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - perform_initialization=False) + self.gate_up_proj = ColumnParallelLinear( + hidden_size, + 2 * intermediate_size, + bias=False, + gather_output=False, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -133,14 +136,12 @@ class BaiChuanAttention(nn.Module): 3 * hidden_size, bias=False, gather_output=False, - perform_initialization=False, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, input_is_parallel=True, - perform_initialization=False, ) # Create the alibi slopes and slice them. if self.postion_embedding == "ALIBI": @@ -249,7 +250,7 @@ class BaiChuanModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, - perform_initialization=False) + ) self.layers = nn.ModuleList([ BaiChuanDecoderLayer(config, position_embedding) for _ in range(config.num_hidden_layers) @@ -288,11 +289,12 @@ class BaiChuanBaseForCausalLM(nn.Module): super().__init__() self.config = config self.model = BaiChuanModel(config, position_embedding) - self.lm_head = ColumnParallelLinear(config.hidden_size, - config.vocab_size, - bias=False, - gather_output=False, - perform_initialization=False) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + bias=False, + gather_output=False, + ) self.sampler = Sampler(config.vocab_size) def forward( diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index d7f7d1910bc5a..f3bb17655c5b3 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -35,8 +35,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator, load_tensor_parallel_weights) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -85,14 +86,12 @@ class BloomAttention(nn.Module): 3 * self.hidden_size, bias=True, gather_output=False, - perform_initialization=False, ) self.dense = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, input_is_parallel=True, - perform_initialization=False, ) # Create the alibi slopes and slice them. @@ -129,15 +128,17 @@ class BloomMLP(nn.Module): def __init__(self, config: BloomConfig): super().__init__() hidden_size = config.hidden_size - self.dense_h_to_4h = ColumnParallelLinear(hidden_size, - 4 * hidden_size, - gather_output=False, - perform_initialization=False) + self.dense_h_to_4h = ColumnParallelLinear( + hidden_size, + 4 * hidden_size, + gather_output=False, + ) self.act = get_act_fn("gelu") - self.dense_4h_to_h = RowParallelLinear(4 * hidden_size, - hidden_size, - input_is_parallel=True, - perform_initialization=False) + self.dense_4h_to_h = RowParallelLinear( + 4 * hidden_size, + hidden_size, + input_is_parallel=True, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x, _ = self.dense_h_to_4h(x) @@ -208,7 +209,9 @@ class BloomModel(nn.Module): # Embedding + LN Embedding self.word_embeddings = VocabParallelEmbedding( - config.vocab_size, self.embed_dim, perform_initialization=False) + config.vocab_size, + self.embed_dim, + ) self.word_embeddings_layernorm = nn.LayerNorm( self.embed_dim, eps=config.layer_norm_epsilon) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index e8e2171fe7552..6c249f6c98fec 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -36,9 +36,11 @@ from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor, load_tensor_parallel_weights) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear, - reduce_from_tensor_model_parallel_region) +from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import RWConfig @@ -109,7 +111,6 @@ class FalconAttention(nn.Module): self.head_dim, bias=config.bias, gather_output=False, - perform_initialization=False, skip_bias_add=True, ) elif self.multi_query: @@ -120,7 +121,6 @@ class FalconAttention(nn.Module): self.total_num_heads * self.head_dim, bias=config.bias, gather_output=False, - perform_initialization=False, skip_bias_add=True, ) self.key_value = FalconLinear(self.hidden_size, @@ -135,7 +135,6 @@ class FalconAttention(nn.Module): self.head_dim, bias=config.bias, gather_output=False, - perform_initialization=False, skip_bias_add=True, ) @@ -151,7 +150,6 @@ class FalconAttention(nn.Module): self.hidden_size, bias=config.bias, input_is_parallel=True, - perform_initialization=False, skip_bias_add=True, reduce_results=self.reduce_row_parallel_results) @@ -231,7 +229,6 @@ class FalconMLP(nn.Module): 4 * hidden_size, bias=config.bias, gather_output=False, - perform_initialization=False, skip_bias_add=True) self.act = nn.GELU() self.reduce_row_parallel_results = not (config.new_decoder_architecture @@ -241,7 +238,6 @@ class FalconMLP(nn.Module): hidden_size, bias=config.bias, input_is_parallel=True, - perform_initialization=False, skip_bias_add=True, reduce_results=self.reduce_row_parallel_results) @@ -325,7 +321,7 @@ class FalconDecoderLayer(nn.Module): # only one all-reduce operator to reduce the results from # both MLP and Attention layers. mlp_output += attention_output - mlp_output = reduce_from_tensor_model_parallel_region(mlp_output) + mlp_output = tensor_model_parallel_all_reduce(mlp_output) if attention_bias is not None: mlp_output += attention_bias if mlp_bias is not None: @@ -347,7 +343,9 @@ class FalconModel(nn.Module): # Embedding + LN Embedding self.word_embeddings = VocabParallelEmbedding( - config.vocab_size, self.embed_dim, perform_initialization=False) + config.vocab_size, + self.embed_dim, + ) # Transformer blocks self.h = nn.ModuleList([ @@ -389,11 +387,12 @@ class FalconForCausalLM(nn.Module): super().__init__() self.config = config self.transformer = FalconModel(config) - self.lm_head = ColumnParallelLinear(config.hidden_size, - config.vocab_size, - bias=False, - gather_output=False, - perform_initialization=False) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + bias=False, + gather_output=False, + ) self.sampler = Sampler(config.vocab_size) def forward( diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index fe7e009aeaf76..b9309eb956544 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -36,8 +36,9 @@ from vllm.model_executor.weight_utils import ( load_padded_tensor_parallel_vocab, load_tensor_parallel_weights) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -56,16 +57,18 @@ class GPT2Attention(nn.Module): self.head_dim = self.hidden_size // total_num_heads self.scale = self.head_dim**-0.5 - self.c_attn = ColumnParallelLinear(self.hidden_size, - 3 * self.hidden_size, - bias=True, - gather_output=False, - perform_initialization=False) - self.c_proj = RowParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - input_is_parallel=True, - perform_initialization=False) + self.c_attn = ColumnParallelLinear( + self.hidden_size, + 3 * self.hidden_size, + bias=True, + gather_output=False, + ) + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + input_is_parallel=True, + ) self.attn = PagedAttention(self.num_heads, self.head_dim, scale=self.scale) @@ -95,16 +98,18 @@ class GPT2MLP(nn.Module): ): super().__init__() hidden_size = config.hidden_size - self.c_fc = ColumnParallelLinear(hidden_size, - intermediate_size, - bias=True, - gather_output=False, - perform_initialization=False) - self.c_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=True, - input_is_parallel=True, - perform_initialization=False) + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + gather_output=False, + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + input_is_parallel=True, + ) self.act = get_act_fn(config.activation_function) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 049b4622839a7..41f72c8cb7086 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -37,8 +37,9 @@ from vllm.model_executor.weight_utils import ( load_padded_tensor_parallel_vocab, load_tensor_parallel_weights) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -62,29 +63,31 @@ class GPTBigCodeAttention(nn.Module): if self.multi_query: self.num_kv_heads = 1 self.kv_dim = self.head_dim - self.c_attn_q = ColumnParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - gather_output=False, - perform_initialization=False) + self.c_attn_q = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + gather_output=False, + ) self.c_attn_kv = nn.Linear(self.hidden_size, 2 * self.kv_dim, bias=True) else: self.num_kv_heads = self.num_heads self.kv_dim = self.num_kv_heads * self.head_dim - self.c_attn = ColumnParallelLinear(self.hidden_size, - self.hidden_size + - 2 * self.kv_dim, - bias=True, - gather_output=False, - perform_initialization=False) + self.c_attn = ColumnParallelLinear( + self.hidden_size, + self.hidden_size + 2 * self.kv_dim, + bias=True, + gather_output=False, + ) - self.c_proj = RowParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - input_is_parallel=True, - perform_initialization=False) + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + input_is_parallel=True, + ) self.attn = PagedAttention(self.num_heads, self.head_dim, scale=self.scale, @@ -124,16 +127,18 @@ class GPTBigMLP(nn.Module): ): super().__init__() hidden_size = config.hidden_size - self.c_fc = ColumnParallelLinear(hidden_size, - intermediate_size, - bias=True, - gather_output=False, - perform_initialization=False) - self.c_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=True, - input_is_parallel=True, - perform_initialization=False) + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + gather_output=False, + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + input_is_parallel=True, + ) self.act = get_act_fn(config.activation_function) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index f8ffcdb7189a5..3606fdc76fb15 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -34,8 +34,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator, load_tensor_parallel_weights) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -49,16 +50,18 @@ class GPTJAttention(nn.Module): self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.total_num_heads - self.qkv_proj = ColumnParallelLinear(config.hidden_size, - 3 * config.hidden_size, - bias=False, - gather_output=False, - perform_initialization=False) - self.out_proj = RowParallelLinear(config.hidden_size, - config.hidden_size, - bias=False, - input_is_parallel=True, - perform_initialization=False) + self.qkv_proj = ColumnParallelLinear( + config.hidden_size, + 3 * config.hidden_size, + bias=False, + gather_output=False, + ) + self.out_proj = RowParallelLinear( + config.hidden_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + ) tp_world_size = get_tensor_model_parallel_world_size() assert self.total_num_heads % tp_world_size == 0 @@ -102,14 +105,16 @@ class GPTJMLP(nn.Module): def __init__(self, intermediate_size: int, config: GPTJConfig): super().__init__() hidden_size = config.n_embd - self.fc_in = ColumnParallelLinear(hidden_size, - intermediate_size, - gather_output=False, - perform_initialization=False) - self.fc_out = RowParallelLinear(intermediate_size, - hidden_size, - input_is_parallel=True, - perform_initialization=False) + self.fc_in = ColumnParallelLinear( + hidden_size, + intermediate_size, + gather_output=False, + ) + self.fc_out = RowParallelLinear( + intermediate_size, + hidden_size, + input_is_parallel=True, + ) self.act = get_act_fn(config.activation_function) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -159,9 +164,10 @@ class GPTJModel(nn.Module): super().__init__() self.config = config self.embed_dim = config.n_embd - self.wte = VocabParallelEmbedding(config.vocab_size, - self.embed_dim, - perform_initialization=False) + self.wte = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + ) self.h = nn.ModuleList( [GPTJBlock(config) for _ in range(config.n_layer)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -199,10 +205,11 @@ class GPTJForCausalLM(nn.Module): self.config = config assert not config.tie_word_embeddings self.transformer = GPTJModel(config) - self.lm_head = ColumnParallelLinear(config.n_embd, - config.vocab_size, - gather_output=False, - perform_initialization=False) + self.lm_head = ColumnParallelLinear( + config.n_embd, + config.vocab_size, + gather_output=False, + ) self.sampler = Sampler(config.vocab_size) def forward( diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 225726a630cf5..d0187c93c541e 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -34,8 +34,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator, load_tensor_parallel_weights) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -59,11 +60,12 @@ class GPTNeoXAttention(nn.Module): config.hidden_size, 3 * config.hidden_size, gather_output=False, - perform_initialization=False) - self.dense = RowParallelLinear(config.hidden_size, - config.hidden_size, - input_is_parallel=True, - perform_initialization=False) + ) + self.dense = RowParallelLinear( + config.hidden_size, + config.hidden_size, + input_is_parallel=True, + ) scaling = self.head_size**-0.5 rotary_dim = int(self.head_size * config.rotary_pct) @@ -100,14 +102,16 @@ class GPTNeoXMLP(nn.Module): def __init__(self, config: GPTNeoXConfig): super().__init__() - self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size, - config.intermediate_size, - gather_output=False, - perform_initialization=False) - self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, - config.hidden_size, - input_is_parallel=True, - perform_initialization=False) + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + gather_output=False, + ) + self.dense_4h_to_h = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + input_is_parallel=True, + ) self.act = get_act_fn(config.hidden_act) def forward(self, hidden_states): @@ -169,9 +173,10 @@ class GPTNeoXModel(nn.Module): super().__init__() self.config = config - self.embed_in = VocabParallelEmbedding(config.vocab_size, - config.hidden_size, - perform_initialization=False) + self.embed_in = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) self.layers = nn.ModuleList( [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)]) self.final_layer_norm = nn.LayerNorm(config.hidden_size, @@ -209,11 +214,12 @@ class GPTNeoXForCausalLM(nn.Module): super().__init__() self.config = config self.gpt_neox = GPTNeoXModel(config) - self.embed_out = ColumnParallelLinear(config.hidden_size, - config.vocab_size, - bias=False, - gather_output=False, - perform_initialization=False) + self.embed_out = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + bias=False, + gather_output=False, + ) self.sampler = Sampler(config.vocab_size) def forward( diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py index 55bd76be01409..ce35eaf4f6a38 100644 --- a/vllm/model_executor/models/internlm.py +++ b/vllm/model_executor/models/internlm.py @@ -12,8 +12,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.tensor_parallel import ( - ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding) from vllm.model_executor.weight_utils import ( hf_model_weights_iterator, load_padded_tensor_parallel_vocab, load_tensor_parallel_weights) @@ -31,16 +32,18 @@ class InternLMMLP(nn.Module): hidden_act: str, ): super().__init__() - self.gate_up_proj = ColumnParallelLinear(hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - perform_initialization=False) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - perform_initialization=False) + self.gate_up_proj = ColumnParallelLinear( + hidden_size, + 2 * intermediate_size, + bias=False, + gather_output=False, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -80,14 +83,12 @@ class InternLMAttention(nn.Module): 3 * self.total_num_heads * self.head_dim, bias=True, gather_output=False, - perform_initialization=False, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=True, input_is_parallel=True, - perform_initialization=False, ) self.attn = PagedAttentionWithRoPE( self.num_heads, @@ -176,7 +177,9 @@ class InternLMModel(nn.Module): vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( - vocab_size, config.hidden_size, perform_initialization=False) + vocab_size, + config.hidden_size, + ) self.layers = nn.ModuleList([ InternLMDecoderLayer(config) for _ in range(config.num_hidden_layers) @@ -216,11 +219,12 @@ class InternLMForCausalLM(nn.Module): self.config = config self.model = InternLMModel(config) vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.lm_head = ColumnParallelLinear(config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - perform_initialization=False) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + vocab_size, + bias=False, + gather_output=False, + ) self.sampler = Sampler(config.vocab_size) def forward( diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 287d650da9806..127fd080dc3a8 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -39,8 +39,7 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.quantized_linear import ParallelLinear from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import ( convert_pyslice_to_tensor, hf_model_weights_iterator, @@ -64,13 +63,11 @@ class LlamaMLP(nn.Module): 2 * intermediate_size, bias=False, gather_output=False, - perform_initialization=False, quant_config=quant_config) self.down_proj = ParallelLinear.row(intermediate_size, hidden_size, bias=False, input_is_parallel=True, - perform_initialization=False, quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -127,7 +124,6 @@ class LlamaAttention(nn.Module): self.head_dim, bias=False, gather_output=False, - perform_initialization=False, quant_config=quant_config, ) self.o_proj = ParallelLinear.row( @@ -135,7 +131,6 @@ class LlamaAttention(nn.Module): hidden_size, bias=False, input_is_parallel=True, - perform_initialization=False, quant_config=quant_config, ) self.attn = PagedAttentionWithRoPE( @@ -241,7 +236,9 @@ class LlamaModel(nn.Module): vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( - vocab_size, config.hidden_size, perform_initialization=False) + vocab_size, + config.hidden_size, + ) self.layers = nn.ModuleList([ LlamaDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) @@ -291,7 +288,6 @@ class LlamaForCausalLM(nn.Module): vocab_size, bias=False, gather_output=False, - perform_initialization=False, quant_config=None) self.sampler = Sampler(config.vocab_size) diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index f2a4faa18b17d..d298ea7d2be4e 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -38,8 +38,7 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.quantized_linear import ParallelLinear from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import ( convert_pyslice_to_tensor, hf_model_weights_iterator, @@ -64,13 +63,11 @@ class MistralMLP(nn.Module): 2 * intermediate_size, bias=False, gather_output=False, - perform_initialization=False, quant_config=quant_config) self.down_proj = ParallelLinear.row(intermediate_size, hidden_size, bias=False, input_is_parallel=True, - perform_initialization=False, quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -116,7 +113,6 @@ class MistralAttention(nn.Module): self.head_dim, bias=False, gather_output=False, - perform_initialization=False, quant_config=quant_config, ) self.o_proj = ParallelLinear.row( @@ -124,7 +120,6 @@ class MistralAttention(nn.Module): hidden_size, bias=False, input_is_parallel=True, - perform_initialization=False, quant_config=quant_config, ) self.attn = PagedAttentionWithRoPE(self.num_heads, @@ -225,7 +220,9 @@ class MistralModel(nn.Module): vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( - vocab_size, config.hidden_size, perform_initialization=False) + vocab_size, + config.hidden_size, + ) self.layers = nn.ModuleList([ MistralDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) @@ -275,7 +272,6 @@ class MistralForCausalLM(nn.Module): vocab_size, bias=False, gather_output=False, - perform_initialization=False, quant_config=None) self.sampler = Sampler(config.vocab_size) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 293d77b6aa1d9..ba7441e145b16 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -15,8 +15,9 @@ from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor, load_tensor_parallel_weights) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.mpt import MPTConfig @@ -53,7 +54,6 @@ class MPTAttention(nn.Module): 3 * self.d_model, bias=not config.no_bias, gather_output=False, - perform_initialization=False, ) if self.qk_ln: self.q_ln = nn.LayerNorm(self.d_model) @@ -63,7 +63,6 @@ class MPTAttention(nn.Module): self.d_model, bias=not config.no_bias, input_is_parallel=True, - perform_initialization=False, ) tp_world_size = get_tensor_model_parallel_world_size() @@ -113,17 +112,19 @@ class MPTMLP(nn.Module): hidden_size = config.d_model expansion_ratio = config.expansion_ratio intermediate_size = expansion_ratio * hidden_size - self.up_proj = ColumnParallelLinear(hidden_size, - intermediate_size, - bias=not config.no_bias, - gather_output=False, - perform_initialization=False) + self.up_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=not config.no_bias, + gather_output=False, + ) self.act = get_act_fn("gelu") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=not config.no_bias, - input_is_parallel=True, - perform_initialization=False) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=not config.no_bias, + input_is_parallel=True, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x, _ = self.up_proj(x) @@ -172,9 +173,10 @@ class MPTModel(nn.Module): assert config.embedding_fraction == 1.0 assert config.norm_type == "low_precision_layernorm" - self.wte = VocabParallelEmbedding(config.vocab_size, - config.d_model, - perform_initialization=False) + self.wte = VocabParallelEmbedding( + config.vocab_size, + config.d_model, + ) self.blocks = nn.ModuleList( [MPTBlock(config) for _ in range(config.n_layers)]) self.norm_f = nn.LayerNorm(config.d_model) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 2064e1aec2afa..5295c73981856 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -35,8 +35,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator, load_tensor_parallel_weights) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -73,16 +74,18 @@ class OPTAttention(nn.Module): self.head_dim = embed_dim // total_num_heads self.scaling = self.head_dim**-0.5 - self.qkv_proj = ColumnParallelLinear(embed_dim, - 3 * embed_dim, - bias=bias, - gather_output=False, - perform_initialization=False) - self.out_proj = RowParallelLinear(embed_dim, - embed_dim, - bias=bias, - input_is_parallel=True, - perform_initialization=False) + self.qkv_proj = ColumnParallelLinear( + embed_dim, + 3 * embed_dim, + bias=bias, + gather_output=False, + ) + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + input_is_parallel=True, + ) self.attn = PagedAttention(self.num_heads, self.head_dim, scale=self.scaling) @@ -120,16 +123,18 @@ class OPTDecoderLayer(nn.Module): self.self_attn_layer_norm = nn.LayerNorm( self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) - self.fc1 = ColumnParallelLinear(self.embed_dim, - config.ffn_dim, - bias=config.enable_bias, - gather_output=False, - perform_initialization=False) - self.fc2 = RowParallelLinear(config.ffn_dim, - self.embed_dim, - bias=config.enable_bias, - input_is_parallel=True, - perform_initialization=False) + self.fc1 = ColumnParallelLinear( + self.embed_dim, + config.ffn_dim, + bias=config.enable_bias, + gather_output=False, + ) + self.fc2 = RowParallelLinear( + config.ffn_dim, + self.embed_dim, + bias=config.enable_bias, + input_is_parallel=True, + ) self.final_layer_norm = nn.LayerNorm( self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) @@ -182,7 +187,7 @@ class OPTDecoder(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.word_embed_proj_dim, - perform_initialization=False) + ) # Positional embeddings are replicated (not sharded). self.embed_positions = OPTLearnedPositionalEmbedding( config.max_position_embeddings, config.hidden_size) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 0a3213becb653..bd5280b35cc34 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -28,7 +28,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.parallel_utils.tensor_parallel import ( +from vllm.model_executor.parallel_utils.layers import ( VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear, @@ -53,14 +53,12 @@ class QWenMLP(nn.Module): 2 * intermediate_size, bias=False, gather_output=False, - perform_initialization=False, ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, input_is_parallel=True, - perform_initialization=False, ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -98,14 +96,12 @@ class QWenAttention(nn.Module): 3 * hidden_size, bias=True, gather_output=False, - perform_initialization=False, ) self.c_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, input_is_parallel=True, - perform_initialization=False, ) self.scaling = self.head_dim**-0.5 self.attn = PagedAttentionWithRoPE( @@ -190,9 +186,10 @@ class QWenModel(nn.Module): self.vocab_size = config.vocab_size vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.wte = VocabParallelEmbedding(vocab_size, - config.hidden_size, - perform_initialization=False) + self.wte = VocabParallelEmbedding( + vocab_size, + config.hidden_size, + ) self.h = nn.ModuleList( [QWenBlock(config) for _ in range(config.num_hidden_layers)]) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -235,7 +232,6 @@ class QWenLMHeadModel(nn.Module): vocab_size, bias=False, gather_output=False, - perform_initialization=False, ) self.sampler = Sampler(config.vocab_size) diff --git a/vllm/model_executor/parallel_utils/__init__.py b/vllm/model_executor/parallel_utils/__init__.py index de13976a3f488..e69de29bb2d1d 100644 --- a/vllm/model_executor/parallel_utils/__init__.py +++ b/vllm/model_executor/parallel_utils/__init__.py @@ -1,7 +0,0 @@ -import vllm.model_executor.parallel_utils.parallel_state -import vllm.model_executor.parallel_utils.tensor_parallel - -__all__ = [ - "parallel_state", - "tensor_parallel", -] diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py new file mode 100644 index 0000000000000..f977397f9d54f --- /dev/null +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -0,0 +1,47 @@ +import torch + +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_group, +) + + +def tensor_model_parallel_all_reduce(input_): + """All-reduce the input tensor across model parallel group. + + Note: This operation is applied in-place on the input tensor. + """ + # Bypass the function if we are using only 1 GPU. + if get_tensor_model_parallel_world_size() == 1: + return input_ + # All-reduce. + torch.distributed.all_reduce(input_, + group=get_tensor_model_parallel_group()) + return input_ + + +def tensor_model_parallel_all_gather(input_, dim=-1): + """All-gather the input tensor across model parallel group.""" + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty((world_size, ) + input_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=get_tensor_model_parallel_group()) + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (world_size * input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor diff --git a/vllm/model_executor/parallel_utils/layers.py b/vllm/model_executor/parallel_utils/layers.py new file mode 100644 index 0000000000000..6b5ecc4c6a928 --- /dev/null +++ b/vllm/model_executor/parallel_utils/layers.py @@ -0,0 +1,303 @@ +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch +from typing import Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather) + +from vllm.model_executor.parallel_utils.utils import ( + divide, + VocabUtility, + split_tensor_along_last_dim, +) + + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + params_dtype: type of the parameters. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + params_dtype: Optional[torch.dtype] = None): + super().__init__() + + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + self.tp_size = get_tensor_model_parallel_world_size() + # TODO: Handle vocab padding here. + # Divide the weight matrix along the vocaburaly dimension. + self.vocab_start_index, self.vocab_end_index = ( + VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, get_tensor_model_parallel_rank(), + self.tp_size)) + self.num_embeddings_per_partition = (self.vocab_end_index - + self.vocab_start_index) + + self.weight = Parameter( + torch.empty(self.num_embeddings_per_partition, + self.embedding_dim, + device=torch.cuda.current_device(), + dtype=params_dtype)) + + def forward(self, input_): + if self.tp_size > 1: + # Build the mask. + input_mask = ((input_ < self.vocab_start_index) | + (input_ >= self.vocab_end_index)) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + else: + masked_input = input_ + # Get the embeddings. + output_parallel = F.embedding(masked_input, self.weight) + # Mask the output embedding. + if self.tp_size > 1: + output_parallel[input_mask, :] = 0.0 + # Reduce across all the model parallel GPUs. + output = tensor_model_parallel_all_reduce(output_parallel) + return output + + +class ColumnParallelLinear(torch.nn.Module): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + + Keyword Arguments + bias: If true, add bias + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configuration. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.gather_output = gather_output + # Divide the weight matrix along the last dimension. + self.tp_size = get_tensor_model_parallel_world_size() + self.output_size_per_partition = divide(output_size, self.tp_size) + self.skip_bias_add = skip_bias_add + self.quant_config = quant_config + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + self.create_weights(params_dtype) + + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, + device=torch.cuda.current_device(), + dtype=params_dtype)) + else: + self.register_parameter('bias', None) + + def create_weights(self, dtype: torch.dtype) -> None: + self.weight = Parameter( + torch.empty(self.output_size_per_partition, + self.input_size, + device=torch.cuda.current_device(), + dtype=dtype)) + + def apply_weights( + self, + x: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + return F.linear(x, self.weight, bias) + + def forward(self, input_): + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = self.bias if not self.skip_bias_add else None + + input_parallel = input_ + # Matrix multiply. + output_parallel = self.apply_weights(input_parallel, bias) + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class RowParallelLinear(torch.nn.Module): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + + Keyword Arguments: + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + skip_bias_add: This was added to enable performance optimization where + bias can be fused with other element-wise operations. + We skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configuration. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + # Divide the weight matrix along the last dimension. + self.tp_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = divide(input_size, self.tp_size) + self.skip_bias_add = skip_bias_add + self.quant_config = quant_config + + self.create_weights(params_dtype) + + if not reduce_results and (bias and not skip_bias_add): + raise ValueError('When not reduce the results, adding bias to the ' + 'results can lead to incorrect results') + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, + device=torch.cuda.current_device(), + dtype=params_dtype)) + + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def create_weights(self, dtype: torch.dtype) -> None: + self.weight = Parameter( + torch.empty(self.output_size, + self.input_size_per_partition, + device=torch.cuda.current_device(), + dtype=dtype)) + + def apply_weights(self, x: torch.Tensor) -> torch.Tensor: + return F.linear(x, self.weight) + + def forward(self, input_): + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + output_parallel = self.apply_weights(input_parallel) + if self.reduce_results and self.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.skip_bias_add: + output = output_ + self.bias if self.bias is not None else output_ + output_bias = None + else: + output = output_ + output_bias = self.bias + return output, output_bias diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py index e5a43258ee10c..53871c85a8620 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -1,78 +1,42 @@ # Copyright 2023 The vLLM team. -# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - """Model and data parallel groups.""" import torch -from typing import Optional -# Intra-layer model parallel group that the current rank belongs to. +# Tensor model parallel group that the current rank belongs to. _TENSOR_MODEL_PARALLEL_GROUP = None -# Inter-layer model parallel group that the current rank belongs to. +# Pipeline model parallel group that the current rank belongs to. _PIPELINE_MODEL_PARALLEL_GROUP = None -# Model parallel group (both intra- and pipeline) that the current rank belongs to. -_MODEL_PARALLEL_GROUP = None -# Embedding group. -_EMBEDDING_GROUP = None -# Position embedding group. -_POSITION_EMBEDDING_GROUP = None -# Data parallel group that the current rank belongs to. -_DATA_PARALLEL_GROUP = None -_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None -_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None -_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None - -# These values enable us to change the mpu sizes on the fly. -_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None -_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None -_MPU_TENSOR_MODEL_PARALLEL_RANK = None -_MPU_PIPELINE_MODEL_PARALLEL_RANK = None - -# A list of ranks that have a copy of the embedding. -_EMBEDDING_GLOBAL_RANKS = None - -# A list of ranks that have a copy of the position embedding. -_POSITION_EMBEDDING_GLOBAL_RANKS = None - -# A list of global ranks for each pipeline group to ease calculation of the source -# rank when broadcasting from the first or last pipeline stage. +# A list of global ranks for each pipeline group to ease calculation of the +# source rank when broadcasting from the first or last pipeline stage. _PIPELINE_GLOBAL_RANKS = None -# A list of global ranks for each data parallel group to ease calculation of the source -# rank when broadcasting weights from src to all other data parallel ranks -_DATA_PARALLEL_GLOBAL_RANKS = None - def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, - virtual_pipeline_model_parallel_size: Optional[int] = None, - pipeline_model_parallel_split_rank: Optional[int] = None, ) -> None: """ - Initialize model data parallel groups. + Initialize model parallel groups. Arguments: - tensor_model_parallel_size: number of GPUs used for tensor model parallelism. - pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism. - virtual_pipeline_model_parallel_size: number of virtual stages (interleaved - pipeline). - pipeline_model_parallel_split_rank: for models with both encoder and decoder, - rank in pipeline with split point. + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. - Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize the model pipeline. The present function will - create 8 tensor model-parallel groups, 4 pipeline model-parallel groups - and 8 data-parallel groups as: - 8 data_parallel groups: - [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] - 8 tensor model-parallel groups: - [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] - 4 pipeline model-parallel groups: - [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] Note that for efficiency, the caller should make sure adjacent ranks are on the same DGX box. For example if we are using 2 DGX-1 boxes with a total of 16 GPUs, rank 0 to 7 belong to the first box and @@ -82,64 +46,23 @@ def initialize_model_parallel( assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() - if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0: + if (world_size != + tensor_model_parallel_size * pipeline_model_parallel_size): raise RuntimeError( - f"world_size ({world_size}) is not divisible by tensor_model_parallel_size " - f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})" - ) - - data_parallel_size: int = world_size // (tensor_model_parallel_size * - pipeline_model_parallel_size) - - num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size - num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size - num_data_parallel_groups: int = world_size // data_parallel_size - - if virtual_pipeline_model_parallel_size is not None: - if not pipeline_model_parallel_size > 2: - raise RuntimeError("pipeline-model-parallel size should be greater than 2 with " - "interleaved schedule") - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 - _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size - - if pipeline_model_parallel_split_rank is not None: - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank + f"world_size ({world_size}) is not equal to " + f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") + num_tensor_model_parallel_groups: int = (world_size // + tensor_model_parallel_size) + num_pipeline_model_parallel_groups: int = (world_size // + pipeline_model_parallel_size) rank = torch.distributed.get_rank() - # Build the data-parallel groups. - global _DATA_PARALLEL_GROUP - global _DATA_PARALLEL_GLOBAL_RANKS - assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized' - all_data_parallel_group_ranks = [] - for i in range(pipeline_model_parallel_size): - start_rank = i * num_pipeline_model_parallel_groups - end_rank = (i + 1) * num_pipeline_model_parallel_groups - for j in range(tensor_model_parallel_size): - ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) - all_data_parallel_group_ranks.append(list(ranks)) - group = torch.distributed.new_group(ranks) - if rank in ranks: - _DATA_PARALLEL_GROUP = group - _DATA_PARALLEL_GLOBAL_RANKS = ranks - - # Build the model-parallel groups. - global _MODEL_PARALLEL_GROUP - assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized' - for i in range(data_parallel_size): - ranks = [data_parallel_group_ranks[i] - for data_parallel_group_ranks in all_data_parallel_group_ranks] - group = torch.distributed.new_group(ranks) - if rank in ranks: - _MODEL_PARALLEL_GROUP = group - # Build the tensor model-parallel groups. global _TENSOR_MODEL_PARALLEL_GROUP - assert _TENSOR_MODEL_PARALLEL_GROUP is None, \ - 'tensor model parallel group is already initialized' + assert _TENSOR_MODEL_PARALLEL_GROUP is None, ( + "tensor model parallel group is already initialized") for i in range(num_tensor_model_parallel_groups): ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) @@ -147,268 +70,60 @@ def initialize_model_parallel( if rank in ranks: _TENSOR_MODEL_PARALLEL_GROUP = group - # Build the pipeline model-parallel groups and embedding groups - # (first and last rank in each pipeline model-parallel group). + # Build the pipeline model-parallel groups. global _PIPELINE_MODEL_PARALLEL_GROUP global _PIPELINE_GLOBAL_RANKS - assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \ - 'pipeline model parallel group is already initialized' - global _EMBEDDING_GROUP - global _EMBEDDING_GLOBAL_RANKS - assert _EMBEDDING_GROUP is None, 'embedding group is already initialized' - global _POSITION_EMBEDDING_GROUP - global _POSITION_EMBEDDING_GLOBAL_RANKS - assert _POSITION_EMBEDDING_GROUP is None, \ - 'position embedding group is already initialized' + assert _PIPELINE_MODEL_PARALLEL_GROUP is None, ( + "pipeline model parallel group is already initialized") for i in range(num_pipeline_model_parallel_groups): ranks = range(i, world_size, num_pipeline_model_parallel_groups) group = torch.distributed.new_group(ranks) if rank in ranks: _PIPELINE_MODEL_PARALLEL_GROUP = group _PIPELINE_GLOBAL_RANKS = ranks - # Setup embedding group (to exchange gradients between - # first and last stages). - if len(ranks) > 1: - embedding_ranks = [ranks[0], ranks[-1]] - position_embedding_ranks = [ranks[0]] - if pipeline_model_parallel_split_rank is not None: - if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks: - embedding_ranks = [ranks[0], - ranks[pipeline_model_parallel_split_rank], - ranks[-1]] - if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks: - position_embedding_ranks = [ranks[0], - ranks[pipeline_model_parallel_split_rank]] - else: - embedding_ranks = ranks - position_embedding_ranks = ranks - group = torch.distributed.new_group(embedding_ranks) - if rank in embedding_ranks: - _EMBEDDING_GROUP = group - if rank in ranks: - _EMBEDDING_GLOBAL_RANKS = embedding_ranks - - group = torch.distributed.new_group(position_embedding_ranks) - if rank in position_embedding_ranks: - _POSITION_EMBEDDING_GROUP = group - if rank in ranks: - _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks def model_parallel_is_initialized(): """Check if model and data parallel groups are initialized.""" - if _TENSOR_MODEL_PARALLEL_GROUP is None or \ - _PIPELINE_MODEL_PARALLEL_GROUP is None or \ - _DATA_PARALLEL_GROUP is None: - return False - return True - - -def get_model_parallel_group(): - """Get the model parallel group the caller rank belongs to.""" - assert _MODEL_PARALLEL_GROUP is not None, \ - 'model parallel group is not initialized' - return _MODEL_PARALLEL_GROUP + return (_TENSOR_MODEL_PARALLEL_GROUP is not None + and _PIPELINE_MODEL_PARALLEL_GROUP is not None) def get_tensor_model_parallel_group(): """Get the tensor model parallel group the caller rank belongs to.""" - assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \ - 'intra_layer_model parallel group is not initialized' + assert _TENSOR_MODEL_PARALLEL_GROUP is not None, ( + "tenosr model parallel group is not initialized") return _TENSOR_MODEL_PARALLEL_GROUP def get_pipeline_model_parallel_group(): """Get the pipeline model parallel group the caller rank belongs to.""" - assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \ - 'pipeline_model parallel group is not initialized' + assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, ( + "pipeline model parallel group is not initialized") return _PIPELINE_MODEL_PARALLEL_GROUP -def get_data_parallel_group(): - """Get the data parallel group the caller rank belongs to.""" - assert _DATA_PARALLEL_GROUP is not None, \ - 'data parallel group is not initialized' - return _DATA_PARALLEL_GROUP - - -def get_embedding_group(): - """Get the embedding group the caller rank belongs to.""" - assert _EMBEDDING_GROUP is not None, \ - 'embedding group is not initialized' - return _EMBEDDING_GROUP - - -def get_position_embedding_group(): - """Get the position embedding group the caller rank belongs to.""" - assert _POSITION_EMBEDDING_GROUP is not None, \ - 'position embedding group is not initialized' - return _POSITION_EMBEDDING_GROUP - - -def set_tensor_model_parallel_world_size(world_size): - """Set the tensor model parallel size""" - global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size - - -def set_pipeline_model_parallel_world_size(world_size): - """Set the pipeline model parallel size""" - global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size - - def get_tensor_model_parallel_world_size(): """Return world size for the tensor model parallel group.""" - global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: - return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) + return torch.distributed.get_world_size( + group=get_tensor_model_parallel_group()) def get_pipeline_model_parallel_world_size(): """Return world size for the pipeline model parallel group.""" - global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None: - return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group()) - - -def set_tensor_model_parallel_rank(rank): - """Set tensor model parallel rank.""" - global _MPU_TENSOR_MODEL_PARALLEL_RANK - _MPU_TENSOR_MODEL_PARALLEL_RANK = rank - - -def set_pipeline_model_parallel_rank(rank): - """Set pipeline model parallel rank.""" - global _MPU_PIPELINE_MODEL_PARALLEL_RANK - _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank - - -def set_pipeline_model_parallel_split_rank(rank): - """Set pipeline model parallel split rank.""" - global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK - _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank + return torch.distributed.get_world_size( + group=get_pipeline_model_parallel_group()) def get_tensor_model_parallel_rank(): """Return my rank for the tensor model parallel group.""" - global _MPU_TENSOR_MODEL_PARALLEL_RANK - if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None: - return _MPU_TENSOR_MODEL_PARALLEL_RANK return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) def get_pipeline_model_parallel_rank(): """Return my rank for the pipeline model parallel group.""" - global _MPU_PIPELINE_MODEL_PARALLEL_RANK - if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None: - return _MPU_PIPELINE_MODEL_PARALLEL_RANK - return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) - - - -def is_pipeline_first_stage(ignore_virtual=False): - """Return True if in the first pipeline model-parallel stage, False otherwise.""" - if not ignore_virtual: - if get_virtual_pipeline_model_parallel_world_size() is not None and \ - get_virtual_pipeline_model_parallel_rank() != 0: - return False - return get_pipeline_model_parallel_rank() == 0 - - -def is_pipeline_last_stage(ignore_virtual=False): - """Return True if in the last pipeline model-parallel stage, False otherwise.""" - if not ignore_virtual: - virtual_pipeline_model_parallel_world_size = \ - get_virtual_pipeline_model_parallel_world_size() - if virtual_pipeline_model_parallel_world_size is not None and \ - get_virtual_pipeline_model_parallel_rank() != ( - virtual_pipeline_model_parallel_world_size - 1): - return False - return get_pipeline_model_parallel_rank() == ( - get_pipeline_model_parallel_world_size() - 1) - - -def is_rank_in_embedding_group(ignore_virtual=False): - """Return true if current rank is in embedding group, False otherwise.""" - rank = torch.distributed.get_rank() - global _EMBEDDING_GLOBAL_RANKS - if ignore_virtual: - return rank in _EMBEDDING_GLOBAL_RANKS - if rank in _EMBEDDING_GLOBAL_RANKS: - if rank == _EMBEDDING_GLOBAL_RANKS[0]: - return is_pipeline_first_stage(ignore_virtual=False) - elif rank == _EMBEDDING_GLOBAL_RANKS[-1]: - return is_pipeline_last_stage(ignore_virtual=False) - else: - return True - return False - - -def is_rank_in_position_embedding_group(): - """Return true if current rank is in position embedding group, False otherwise.""" - rank = torch.distributed.get_rank() - global _POSITION_EMBEDDING_GLOBAL_RANKS - return rank in _POSITION_EMBEDDING_GLOBAL_RANKS - - -def is_pipeline_stage_before_split(rank=None): - """Return True if pipeline stage executes encoder block for a model - with both encoder and decoder.""" - if get_pipeline_model_parallel_world_size() == 1: - return True - if rank is None: - rank = get_pipeline_model_parallel_rank() - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: - return True - if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: - return True - return False - - -def is_pipeline_stage_after_split(rank=None): - """Return True if pipeline stage executes decoder block for a model - with both encoder and decoder.""" - if get_pipeline_model_parallel_world_size() == 1: - return True - if rank is None: - rank = get_pipeline_model_parallel_rank() - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: - return True - if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: - return True - return False - - -def is_pipeline_stage_at_split(): - """Return true if pipeline stage executes decoder block and next - stage executes encoder block for a model with both encoder and - decoder.""" - rank = get_pipeline_model_parallel_rank() - return is_pipeline_stage_before_split(rank) and \ - is_pipeline_stage_after_split(rank+1) - - -def get_virtual_pipeline_model_parallel_rank(): - """Return the virtual pipeline-parallel rank.""" - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - - -def set_virtual_pipeline_model_parallel_rank(rank): - """Set the virtual pipeline-parallel rank.""" - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank - - -def get_virtual_pipeline_model_parallel_world_size(): - """Return the virtual pipeline-parallel world size.""" - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + return torch.distributed.get_rank( + group=get_pipeline_model_parallel_group()) def get_tensor_model_parallel_src_rank(): @@ -419,35 +134,27 @@ def get_tensor_model_parallel_src_rank(): return (global_rank // local_world_size) * local_world_size -def get_data_parallel_src_rank(): - """Calculate the global rank corresponding to the first local rank - in the data parallel group.""" - assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \ - "Data parallel group is not initialized" - return _DATA_PARALLEL_GLOBAL_RANKS[0] - - def get_pipeline_model_parallel_first_rank(): """Return the global rank of the first process in the pipeline for the current tensor parallel group""" - assert _PIPELINE_GLOBAL_RANKS is not None, \ - "Pipeline parallel group is not initialized" + assert _PIPELINE_GLOBAL_RANKS is not None, ( + "Pipeline parallel group is not initialized") return _PIPELINE_GLOBAL_RANKS[0] def get_pipeline_model_parallel_last_rank(): """Return the global rank of the last process in the pipeline for the current tensor parallel group""" - assert _PIPELINE_GLOBAL_RANKS is not None, \ - "Pipeline parallel group is not initialized" + assert _PIPELINE_GLOBAL_RANKS is not None, ( + "Pipeline parallel group is not initialized") last_rank_local = get_pipeline_model_parallel_world_size() - 1 return _PIPELINE_GLOBAL_RANKS[last_rank_local] def get_pipeline_model_parallel_next_rank(): """Return the global rank that follows the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, \ - "Pipeline parallel group is not initialized" + assert _PIPELINE_GLOBAL_RANKS is not None, ( + "Pipeline parallel group is not initialized") rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] @@ -455,45 +162,18 @@ def get_pipeline_model_parallel_next_rank(): def get_pipeline_model_parallel_prev_rank(): """Return the global rank that preceeds the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, \ - "Pipeline parallel group is not initialized" + assert _PIPELINE_GLOBAL_RANKS is not None, ( + "Pipeline parallel group is not initialized") rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] -def get_data_parallel_world_size(): - """Return world size for the data parallel group.""" - return torch.distributed.get_world_size(group=get_data_parallel_group()) - - -def get_data_parallel_rank(): - """Return my rank for the data parallel group.""" - return torch.distributed.get_rank(group=get_data_parallel_group()) - def destroy_model_parallel(): """Set the groups to none.""" - global _MODEL_PARALLEL_GROUP - _MODEL_PARALLEL_GROUP = None global _TENSOR_MODEL_PARALLEL_GROUP _TENSOR_MODEL_PARALLEL_GROUP = None global _PIPELINE_MODEL_PARALLEL_GROUP _PIPELINE_MODEL_PARALLEL_GROUP = None - global _DATA_PARALLEL_GROUP - _DATA_PARALLEL_GROUP = None - global _EMBEDDING_GROUP - _EMBEDDING_GROUP = None - global _POSITION_EMBEDDING_GROUP - _POSITION_EMBEDDING_GROUP = None - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None - global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None - global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None - global _MPU_TENSOR_MODEL_PARALLEL_RANK - _MPU_TENSOR_MODEL_PARALLEL_RANK = None - global _MPU_PIPELINE_MODEL_PARALLEL_RANK - _MPU_PIPELINE_MODEL_PARALLEL_RANK = None + global _PIPELINE_GLOBAL_RANKS + _PIPELINE_GLOBAL_RANKS = None diff --git a/vllm/model_executor/parallel_utils/tensor_parallel/__init__.py b/vllm/model_executor/parallel_utils/tensor_parallel/__init__.py deleted file mode 100644 index d17f12f3adabb..0000000000000 --- a/vllm/model_executor/parallel_utils/tensor_parallel/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -from .layers import ( - ColumnParallelLinear, - RowParallelLinear, - VocabParallelEmbedding, - set_tensor_model_parallel_attributes, - set_defaults_if_not_set_tensor_model_parallel_attributes, - copy_tensor_model_parallel_attributes, - param_is_not_tensor_parallel_duplicate, -) - -from .mappings import ( - copy_to_tensor_model_parallel_region, - gather_from_tensor_model_parallel_region, - gather_from_sequence_parallel_region, - reduce_from_tensor_model_parallel_region, - scatter_to_tensor_model_parallel_region, - scatter_to_sequence_parallel_region, -) - -from .random import ( - get_cuda_rng_tracker, - model_parallel_cuda_manual_seed, -) - -from .utils import ( - split_tensor_along_last_dim, -) - -__all__ = [ - #layers.py - "ColumnParallelLinear", - "RowParallelLinear", - "VocabParallelEmbedding", - "set_tensor_model_parallel_attributes", - "set_defaults_if_not_set_tensor_model_parallel_attributes", - "copy_tensor_model_parallel_attributes", - "param_is_not_tensor_parallel_duplicate", - # mappings.py - "copy_to_tensor_model_parallel_region", - "gather_from_tensor_model_parallel_region", - "gather_from_sequence_parallel_region", - "reduce_from_tensor_model_parallel_region", - "scatter_to_tensor_model_parallel_region", - "scatter_to_sequence_parallel_region", - # random.py - "get_cuda_rng_tracker", - "model_parallel_cuda_manual_seed", - # utils.py - "split_tensor_along_last_dim", -] diff --git a/vllm/model_executor/parallel_utils/tensor_parallel/layers.py b/vllm/model_executor/parallel_utils/tensor_parallel/layers.py deleted file mode 100644 index bfaf9c5f7349c..0000000000000 --- a/vllm/model_executor/parallel_utils/tensor_parallel/layers.py +++ /dev/null @@ -1,366 +0,0 @@ -# Copyright 2023 The vLLM team. -# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -# Parts of the code here are adapted from PyTorch -# repo: https://github.com/pytorch/pytorch -from typing import Optional - -import torch -import torch.nn.functional as F -import torch.nn.init as init -from torch.nn.parameter import Parameter - -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from .mappings import ( - gather_from_tensor_model_parallel_region, - reduce_from_tensor_model_parallel_region, - scatter_to_tensor_model_parallel_region, -) - -from .utils import ( - divide, - VocabUtility, -) - -_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False, - 'partition_dim': -1, - 'partition_stride': 1} - -def param_is_not_tensor_parallel_duplicate(param): - return (hasattr(param, 'tensor_model_parallel') and - param.tensor_model_parallel) or ( - get_tensor_model_parallel_rank() == 0) - - -def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride): - # Make sure the attributes are not set. - for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: - assert not hasattr(tensor, attribute) - # Set the attributes. - setattr(tensor, 'tensor_model_parallel', is_parallel) - setattr(tensor, 'partition_dim', dim) - setattr(tensor, 'partition_stride', stride) - - -def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor): - def maybe_set(attribute, value): - if not hasattr(tensor, attribute): - setattr(tensor, attribute, value) - for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: - maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute]) - - -def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor): - def maybe_copy(attribute): - if hasattr(source_tensor, attribute): - setattr(destination_tensor, attribute, - getattr(source_tensor, attribute)) - for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: - maybe_copy(attribute) - - -class VocabParallelEmbedding(torch.nn.Module): - """Embedding parallelized in the vocabulary dimension. - - This is mainly adapted from torch.nn.Embedding and all the default - values are kept. - Arguments: - num_embeddings: vocabulary size. - embedding_dim: size of hidden state. - - Keyword Arguments: - init_method: method to initialize weights. - params_dtype - use_cpu_initialization - perform_initialization - """ - - def __init__(self, num_embeddings: int, embedding_dim: int, *, - init_method=init.xavier_normal_, - params_dtype: torch.dtype=None, - use_cpu_initialization: bool=False, - perform_initialization: bool=False): - super(VocabParallelEmbedding, self).__init__() - assert not perform_initialization - assert not use_cpu_initialization - - # Keep the input dimensions. - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - # Set the defaults for compatibility. - self.padding_idx = None - self.max_norm = None - self.norm_type = 2. - self.scale_grad_by_freq = False - self.sparse = False - self._weight = None - self.tensor_model_parallel_size = get_tensor_model_parallel_world_size() - # Divide the weight matrix along the vocaburaly dimension. - self.vocab_start_index, self.vocab_end_index = \ - VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, get_tensor_model_parallel_rank(), - self.tensor_model_parallel_size) - self.num_embeddings_per_partition = self.vocab_end_index - \ - self.vocab_start_index - - self.weight = Parameter(torch.empty( - self.num_embeddings_per_partition, self.embedding_dim, - device=torch.cuda.current_device(), dtype=params_dtype)) - - def forward(self, input_): - if self.tensor_model_parallel_size > 1: - # Build the mask. - input_mask = (input_ < self.vocab_start_index) | \ - (input_ >= self.vocab_end_index) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - else: - masked_input = input_ - # Get the embeddings. - output_parallel = F.embedding(masked_input, self.weight, - self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, - self.sparse) - # Mask the output embedding. - if self.tensor_model_parallel_size > 1: - output_parallel[input_mask, :] = 0.0 - # Reduce across all the model parallel GPUs. - output = reduce_from_tensor_model_parallel_region(output_parallel) - return output - - -class ColumnParallelLinear(torch.nn.Module): - """Linear layer with column parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its second dimension as A = [A_1, ..., A_p]. - - Arguments: - input_size: first dimension of matrix A. - output_size: second dimension of matrix A. - - Keyword Arguments - bias: If true, add bias - gather_output: If true, call all-gather on output and make Y available - to all GPUs, otherwise, every GPU will have its output - which is Y_i = XA_i - init_method: method to initialize weights. Note that bias is always set - to zero. - stride: For the strided linear layers. - keep_master_weight_for_test: This was added for testing and should be - set to False. It returns the master weights - used for initialization. - skip_bias_add: This was added to enable performance optimations where bias - can be fused with other elementwise operations. we skip - adding bias but instead return it. - params_dtype: - use_cpu_initialization: - """ - - def __init__(self, input_size, output_size, *, - bias=True, gather_output=True, - init_method=init.xavier_normal_, stride=1, - keep_master_weight_for_test=False, - skip_bias_add=False, - params_dtype=None, - use_cpu_initialization=False, - perform_initialization=False, - quant_config=None, - ): - super(ColumnParallelLinear, self).__init__() - assert not perform_initialization - assert not use_cpu_initialization - - # Keep input parameters - self.input_size = input_size - self.output_size = output_size - self.gather_output = gather_output - # Divide the weight matrix along the last dimension. - self.world_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, self.world_size) - self.skip_bias_add = skip_bias_add - self.quant_config = quant_config - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - # Parameters. - # Note: torch.nn.functional.linear performs XA^T + b and as a result - # we allocate the transpose. - self.create_weights(params_dtype) - - if bias: - self.bias = Parameter(torch.empty( - self.output_size_per_partition, - device=torch.cuda.current_device(), - dtype=params_dtype)) - set_tensor_model_parallel_attributes(self.bias, True, 0, stride) - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter('bias', None) - - def create_weights(self, dtype: torch.dtype) -> None: - self.weight = Parameter(torch.empty( - self.output_size_per_partition, self.input_size, - device=torch.cuda.current_device(), dtype=dtype)) - - def apply_weights( - self, - x: torch.Tensor, - bias: Optional[torch.Tensor], - ) -> torch.Tensor: - return F.linear(x, self.weight, bias) - - def forward(self, input_): - """Forward of ColumnParallelLinear - - Args: - input_: 3D tensor whose order of dimension is [sequence, batch, hidden] - - Returns: - - output - - bias - """ - bias = self.bias if not self.skip_bias_add else None - - input_parallel = input_ - # Matrix multiply. - output_parallel = self.apply_weights(input_parallel, bias) - if self.gather_output: - # All-gather across the partitions. - output = gather_from_tensor_model_parallel_region(output_parallel) - else: - output = output_parallel - output_bias = self.bias if self.skip_bias_add else None - return output, output_bias - - -class RowParallelLinear(torch.nn.Module): - """Linear layer with row parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its first dimension and X along its second dimension as: - - - - | A_1 | - | . | - A = | . | X = [X_1, ..., X_p] - | . | - | A_p | - - - - Arguments: - input_size: first dimension of matrix A. - output_size: second dimension of matrix A. - - Keyword Arguments: - bias: If true, add bias. Note that bias is not parallelized. - input_is_parallel: If true, we assume that the input is already - split across the GPUs and we do not split - again. - init_method: method to initialize weights. Note that bias is always set - to zero. - stride: For the strided linear layers. - keep_master_weight_for_test: This was added for testing and should be - set to False. It returns the master weights - used for initialization. - skip_bias_add: This was added to enable performance optimization where bias - can be fused with other elementwise operations. We skip - adding bias but instead return it. - params_dtype: - use_cpu_initialization: - perform_initialization: - reduce_results: - """ - - def __init__(self, input_size, output_size, *, - bias=True, input_is_parallel=False, - init_method=init.xavier_normal_, stride=1, - keep_master_weight_for_test=False, - skip_bias_add=False, - params_dtype=None, - use_cpu_initialization=False, - perform_initialization=False, - reduce_results=True, - quant_config=None, - ): - super(RowParallelLinear, self).__init__() - assert not perform_initialization - assert not use_cpu_initialization - - # Keep input parameters - self.input_size = input_size - self.output_size = output_size - self.input_is_parallel = input_is_parallel - self.reduce_results = reduce_results - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - # Divide the weight matrix along the last dimension. - self.world_size = get_tensor_model_parallel_world_size() - self.input_size_per_partition = divide(input_size, self.world_size) - self.skip_bias_add = skip_bias_add - self.quant_config = quant_config - - self.create_weights(params_dtype) - - if not reduce_results and (bias and not skip_bias_add): - raise ValueError("When not reduce the results, adding bias to the " - "results can lead to incorrect results") - - if bias: - self.bias = Parameter(torch.empty( - self.output_size, device=torch.cuda.current_device(), - dtype=params_dtype)) - - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter('bias', None) - - def create_weights(self, dtype: torch.dtype) -> None: - self.weight = Parameter(torch.empty( - self.output_size, self.input_size_per_partition, - device=torch.cuda.current_device(), dtype=dtype)) - - def apply_weights(self, x: torch.Tensor) -> torch.Tensor: - return F.linear(x, self.weight) - - def forward(self, input_): - """Forward of RowParallelLinear - - Args: - input_: 3D tensor whose order of dimension is [sequence, batch, hidden] - - Returns: - - output - - bias - """ - # Set up backprop all-reduce. - if self.input_is_parallel: - input_parallel = input_ - else: - input_parallel = scatter_to_tensor_model_parallel_region(input_) - # Matrix multiply. - output_parallel = self.apply_weights(input_parallel) - if self.reduce_results and self.world_size > 1: - output_ = reduce_from_tensor_model_parallel_region(output_parallel) - else: - output_ = output_parallel - - if not self.skip_bias_add: - output = output_ + self.bias if self.bias is not None else output_ - output_bias = None - else: - output = output_ - output_bias = self.bias - return output, output_bias diff --git a/vllm/model_executor/parallel_utils/tensor_parallel/mappings.py b/vllm/model_executor/parallel_utils/tensor_parallel/mappings.py deleted file mode 100644 index 62d6403728071..0000000000000 --- a/vllm/model_executor/parallel_utils/tensor_parallel/mappings.py +++ /dev/null @@ -1,281 +0,0 @@ -# Copyright 2023 The vLLM team. -# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/mappings.py -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import torch - -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - get_tensor_model_parallel_group, -) -from .utils import split_tensor_along_last_dim - - -def _reduce(input_): - """All-reduce the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. - if get_tensor_model_parallel_world_size()==1: - return input_ - - # All-reduce. - torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) - - return input_ - - -def _split_along_last_dim(input_): - """Split the tensor along its last dimension and keep the - corresponding slice.""" - - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - # Split along last dimension. - input_list = split_tensor_along_last_dim(input_, world_size) - - # Note: torch.split does not create contiguous tensors by default. - rank = get_tensor_model_parallel_rank() - output = input_list[rank].contiguous() - - return output - - -def _split_along_first_dim(input_): - """Split the tensor along its first dimension and keep the - corresponding slice.""" - - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - # Split along first dimension. - dim_size = input_.size()[0] - assert dim_size % world_size == 0, \ - "First dimension of the tensor should be divisible by tensor parallel size" - local_dim_size = dim_size // world_size - rank = get_tensor_model_parallel_rank() - dim_offset = rank * local_dim_size - - output = input_[dim_offset:dim_offset+local_dim_size].contiguous() - - return output - - -def _gather_along_last_dim(input_): - """Gather tensors and concatinate along the last dimension.""" - - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - # Size and dimension. - last_dim = input_.dim() - 1 - rank = get_tensor_model_parallel_rank() - - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group()) - - # Note: torch.cat already creates a contiguous tensor. - output = torch.cat(tensor_list, dim=last_dim).contiguous() - - return output - - -def _gather_along_first_dim(input_): - """Gather tensors and concatinate along the first dimension.""" - - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - dim_size = list(input_.size()) - dim_size[0] = dim_size[0] * world_size - - output = torch.empty(dim_size, dtype=input_.dtype, - device=torch.cuda.current_device()) - torch.distributed._all_gather_base(output, input_.contiguous(), - group=get_tensor_model_parallel_group()) - - return output - -def _reduce_scatter_along_first_dim(input_): - """Reduce-scatter the input tensor across model parallel group.""" - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - dim_size = list(input_.size()) - assert dim_size[0] % world_size == 0, \ - "First dimension of the tensor should be divisible by tensor parallel size" - - dim_size[0] = dim_size[0] // world_size - - output = torch.empty(dim_size, dtype=input_.dtype, - device=torch.cuda.current_device()) - torch.distributed._reduce_scatter_base(output, input_.contiguous(), - group=get_tensor_model_parallel_group()) - return output - - -class _CopyToModelParallelRegion(torch.autograd.Function): - """Pass the input to the model parallel region.""" - - @staticmethod - def symbolic(graph, input_): - return input_ - - @staticmethod - def forward(ctx, input_): - return input_ - - @staticmethod - def backward(ctx, grad_output): - return _reduce(grad_output) - - -class _ReduceFromModelParallelRegion(torch.autograd.Function): - """All-reduce the input from the model parallel region.""" - - @staticmethod - def symbolic(graph, input_): - return _reduce(input_) - - @staticmethod - def forward(ctx, input_): - return _reduce(input_) - - @staticmethod - def backward(ctx, grad_output): - return grad_output - - -class _ScatterToModelParallelRegion(torch.autograd.Function): - """Split the input and keep only the corresponding chuck to the rank.""" - - @staticmethod - def symbolic(graph, input_): - return _split_along_last_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _split_along_last_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _gather_along_last_dim(grad_output) - - -class _GatherFromModelParallelRegion(torch.autograd.Function): - """Gather the input from model parallel region and concatinate.""" - - @staticmethod - def symbolic(graph, input_): - return _gather_along_last_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _gather_along_last_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _split_along_last_dim(grad_output) - - -class _ScatterToSequenceParallelRegion(torch.autograd.Function): - """Split the input and keep only the corresponding chuck to the rank.""" - - @staticmethod - def symbolic(graph, input_): - return _split_along_first_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _split_along_first_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _gather_along_first_dim(grad_output) - - -class _GatherFromSequenceParallelRegion(torch.autograd.Function): - """Gather the input from sequence parallel region and concatinate.""" - - @staticmethod - def symbolic(graph, input_, tensor_parallel_output_grad=True): - return _gather_along_first_dim(input_) - - @staticmethod - def forward(ctx, input_, tensor_parallel_output_grad=True): - ctx.tensor_parallel_output_grad = tensor_parallel_output_grad - return _gather_along_first_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - tensor_parallel_output_grad = ctx.tensor_parallel_output_grad - - # If the computation graph after the gather operation is - # in the tensor parallel mode, output gradients need to reduce - # scattered and whereas if the computation is duplicated, - # output gradients need to be scattered. - if tensor_parallel_output_grad: - return _reduce_scatter_along_first_dim(grad_output), None - else: - return _split_along_first_dim(grad_output), None - - -class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): - """Reduce scatter the input from the model parallel region.""" - - @staticmethod - def symbolic(graph, input_): - return _reduce_scatter_along_first_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _reduce_scatter_along_first_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _gather_along_first_dim(grad_output) - - -# ----------------- -# Helper functions. -# ----------------- - -def copy_to_tensor_model_parallel_region(input_): - return _CopyToModelParallelRegion.apply(input_) - - -def reduce_from_tensor_model_parallel_region(input_): - return _ReduceFromModelParallelRegion.apply(input_) - - -def scatter_to_tensor_model_parallel_region(input_): - return _ScatterToModelParallelRegion.apply(input_) - - -def gather_from_tensor_model_parallel_region(input_): - return _GatherFromModelParallelRegion.apply(input_) - - -def scatter_to_sequence_parallel_region(input_): - return _ScatterToSequenceParallelRegion.apply(input_) - - -def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True): - return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad) - - -def reduce_scatter_to_sequence_parallel_region(input_): - return _ReduceScatterToSequenceParallelRegion.apply(input_) - diff --git a/vllm/model_executor/parallel_utils/tensor_parallel/random.py b/vllm/model_executor/parallel_utils/tensor_parallel/random.py deleted file mode 100644 index 958e842114433..0000000000000 --- a/vllm/model_executor/parallel_utils/tensor_parallel/random.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright 2023 The vLLM team. -# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -# Parts of the code here are adapted from PyTorch -# repo: https://github.com/pytorch/pytorch - -import contextlib - -import torch -from torch import _C -from torch.cuda import _lazy_call, device as device_ctx_manager - -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, -) - -# Default name for the model parallel rng tracker. -_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' - - -def _set_cuda_rng_state(new_state, device=-1): - """Sets the random number generator state of the current GPU. - - Argumentss: - new_state (torch.ByteTensor): The desired state - This function is adapted from PyTorch repo (torch.cuda.set_rng_state) - with a single change: the input state is not cloned. Cloning caused - major performance issues for +4 GPU cases. - """ - if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): - # older PyTorch - def cb(): - with device_ctx_manager(device): - _C._cuda_setRNGState(new_state) - else: - # newer PyTorch - if device == -1: - device = torch.device('cuda') - elif isinstance(device, str): - device = torch.device(device) - elif isinstance(device, int): - device = torch.device('cuda', device) - - def cb(): - idx = device.index - if idx is None: - idx = torch.cuda.current_device() - default_generator = torch.cuda.default_generators[idx] - default_generator.set_state(new_state) - - _lazy_call(cb) - - - -class CudaRNGStatesTracker: - """Tracker for the cuda RNG states. - - Using the `add` method, a cuda rng state is initialized based on - the input `seed` and is assigned to `name`. Later, by forking the - rng state, we can perform operations and return to our starting - cuda state. - """ - - def __init__(self): - # Map from a string name to the cuda rng state. - self.states_ = {} - # Seeds are just for book keeping and ensure no seed is set twice. - self.seeds_ = set() - - def reset(self): - """Set to the initial state (no tracker).""" - self.states_ = {} - self.seeds_ = set() - - def get_states(self): - """Get rng states. Copy the dictionary so we have direct - pointers to the states, not just a pointer to the dictionary.""" - states = {} - for name in self.states_: - states[name] = self.states_[name] - return states - - def set_states(self, states): - """Set the rng states. For efficiency purposes, we do not check - the size of seed for compatibility.""" - self.states_ = states - - def add(self, name, seed): - """Track the rng state.""" - # Check seed is not already used. - if seed in self.seeds_: - raise Exception('seed {} already exists'.format(seed)) - self.seeds_.add(seed) - # Check that state is not already defined. - if name in self.states_: - raise Exception('cuda rng state {} already exists'.format(name)) - # Get the current rng state. - orig_rng_state = torch.cuda.get_rng_state() - # Set the new state and store it. - torch.cuda.manual_seed(seed) - self.states_[name] = torch.cuda.get_rng_state() - # Reset rng state to what it was. - _set_cuda_rng_state(orig_rng_state) - - @contextlib.contextmanager - def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): - """Fork the cuda rng state, perform operations, and exit with - the original state.""" - # Check if we have added the state - if name not in self.states_: - raise Exception('cuda rng state {} is not added'.format(name)) - # Store current rng state. - orig_cuda_rng_state = torch.cuda.get_rng_state() - # Set rng state to the desired one - _set_cuda_rng_state(self.states_[name]) - # Do the stuff we wanted to do. - try: - yield - finally: - # Update the current rng state for later use. - self.states_[name] = torch.cuda.get_rng_state() - # And set the state to the original state we started with. - _set_cuda_rng_state(orig_cuda_rng_state) - - -# RNG tracker object. -_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - - -def get_cuda_rng_tracker(): - """Get cuda rng tracker.""" - return _CUDA_RNG_STATE_TRACKER - - -def model_parallel_cuda_manual_seed(seed): - """Initialize model parallel cuda seed. - - This function should be called after the model parallel is - initialized. Also, no torch.cuda.manual_seed should be called - after this function. Basically, this is replacement for that - function. - Two set of RNG states are tracked: - default state: This is for data parallelism and is the same among a - set of model parallel GPUs but different across - different model paralle groups. This is used for - example for dropout in the non-tensor-model-parallel regions. - tensor-model-parallel state: This state is different among a set of model - parallel GPUs, but the same across data parallel - groups. This is used for example for dropout in - model parallel regions. - """ - # 2718 is just for fun and any POSITIVE value will work. - offset = seed + 2718 - tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank() - # Data parallel gets the original seed. - data_parallel_seed = seed - - _CUDA_RNG_STATE_TRACKER.reset() - # Set the default state. - torch.cuda.manual_seed(data_parallel_seed) - # and model parallel state. - _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, - tensor_model_parallel_seed) diff --git a/vllm/model_executor/parallel_utils/tensor_parallel/utils.py b/vllm/model_executor/parallel_utils/utils.py similarity index 86% rename from vllm/model_executor/parallel_utils/tensor_parallel/utils.py rename to vllm/model_executor/parallel_utils/utils.py index 2efd123f5667c..5d0a7595a4a2c 100644 --- a/vllm/model_executor/parallel_utils/tensor_parallel/utils.py +++ b/vllm/model_executor/parallel_utils/utils.py @@ -1,15 +1,16 @@ # Copyright 2023 The vLLM team. -# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +from typing import List, Sequence import torch -from typing import List, Sequence + def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" assert numerator % denominator == 0, "{} is not divisible by {}".format( - numerator, denominator - ) + numerator, denominator) def divide(numerator, denominator): @@ -56,15 +57,14 @@ class VocabUtility: @staticmethod def vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size: int, rank, world_size: int - ) -> Sequence[int]: + per_partition_vocab_size: int, rank: int) -> Sequence[int]: index_f = rank * per_partition_vocab_size index_l = index_f + per_partition_vocab_size return index_f, index_l @staticmethod - def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]: + def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, + world_size: int) -> Sequence[int]: per_partition_vocab_size = divide(global_vocab_size, world_size) return VocabUtility.vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size, rank, world_size - ) + per_partition_vocab_size, rank) diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 739eac7d2de0e..bd74ae96aa19e 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -4,9 +4,6 @@ import random import numpy as np import torch -from vllm.model_executor.parallel_utils.parallel_state import model_parallel_is_initialized -from vllm.model_executor.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed - def set_random_seed(seed: int) -> None: random.seed(seed) @@ -14,6 +11,3 @@ def set_random_seed(seed: int) -> None: torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) - - if model_parallel_is_initialized(): - model_parallel_cuda_manual_seed(seed)