# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch

import math
import os
from typing import Optional
import warnings

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter

from cacheflow.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_group,
    get_global_memory_buffer,
)
from .mappings import (
    copy_to_tensor_model_parallel_region,
    gather_from_tensor_model_parallel_region,
    gather_from_sequence_parallel_region,
    reduce_from_tensor_model_parallel_region,
    scatter_to_tensor_model_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
)

from .random import get_cuda_rng_tracker
from .utils import (
    divide,
    split_tensor_along_last_dim,
    VocabUtility,
)

# The fused weight-gradient accumulation kernels are optional; layers that
# request gradient_accumulation_fusion check this flag and fail loudly.
_grad_accum_fusion_available = True
try:
    import fused_weight_gradient_mlp_cuda
except ImportError:
    _grad_accum_fusion_available = False

# Attribute names (and their default values) used to tag tensors with
# tensor-model-parallel partitioning metadata.
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                      'partition_dim': -1,
                                      'partition_stride': 1}


def param_is_not_tensor_parallel_duplicate(param):
    """Return False only for a replicated parameter on a non-zero TP rank.

    A parameter tagged with a truthy ``tensor_model_parallel`` attribute is
    always reported as non-duplicate. Any other parameter is reported as
    non-duplicate only on tensor-model-parallel rank 0.
    """
    return (hasattr(param, 'tensor_model_parallel') and
            param.tensor_model_parallel) or (
                get_tensor_model_parallel_rank() == 0)


def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    """Tag ``tensor`` with tensor-model-parallel partitioning metadata.

    Sets the ``tensor_model_parallel``, ``partition_dim`` and
    ``partition_stride`` attributes on the tensor object.

    Raises:
        AssertionError: if any of those attributes is already present.
    """
    # Make sure the attributes are not set.
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        assert not hasattr(tensor, attribute)
    # Set the attributes.
    setattr(tensor, 'tensor_model_parallel', is_parallel)
    setattr(tensor, 'partition_dim', dim)
    setattr(tensor, 'partition_stride', stride)


def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
    """Fill in any missing partitioning attribute on ``tensor`` with its default."""
    def maybe_set(attribute, value):
        if not hasattr(tensor, attribute):
            setattr(tensor, attribute, value)
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])


def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
    """Copy each partitioning attribute present on ``source_tensor`` to
    ``destination_tensor``. Attributes absent on the source are left alone."""
    def maybe_copy(attribute):
        if hasattr(source_tensor, attribute):
            setattr(destination_tensor, attribute,
                    getattr(source_tensor, attribute))
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        maybe_copy(attribute)


def _initialize_affine_weight_gpu(weight, init_method,
                                  partition_dim, stride=1):
    """Initialize affine weight for model parallel on GPU.

    Tags ``weight`` as tensor-parallel along ``partition_dim``. It then
    runs ``init_method`` directly on the local shard inside the CUDA RNG
    tracker's forked state.
    """

    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    with get_cuda_rng_tracker().fork():
        init_method(weight)


def _initialize_affine_weight_cpu(weight, output_size, input_size,
                                  per_partition_size, partition_dim,
                                  init_method, stride=1,
                                  return_master_weight=False,
                                  *, params_dtype=None):
    """Initialize affine weight for model parallel.

    Build the master weight on all processes and scatter the relevant chunk.

    The full ``(output_size, input_size)`` master weight is created in
    float32 and initialized with ``init_method``. It is then cast to
    ``params_dtype``, which defaults to ``torch.get_default_dtype()``.

    The master weight is split along ``partition_dim`` into chunks of
    ``per_partition_size // stride`` rows/columns. This rank keeps chunks
    ``rank, rank + world_size, ...`` and concatenates them into ``weight``
    in place.

    Returns:
        The cast master weight if ``return_master_weight`` is True,
        otherwise None.
    """

    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    if params_dtype is None:
        params_dtype = torch.get_default_dtype()

    # Initialize master weight
    master_weight = torch.empty(output_size, input_size,
                                dtype=torch.float,
                                requires_grad=False)
    init_method(master_weight)
    master_weight = master_weight.to(dtype=params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size,
                              dim=partition_dim)
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    # Strided selection: with stride > 1 each rank owns several
    # non-adjacent chunks of the master weight.
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)
    if return_master_weight:
        return master_weight
    return None


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.

    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.

    Keyword Arguments:
        init_method: method to initialize weights.
        params_dtype: dtype of the weight; defaults to
            ``torch.get_default_dtype()``.
        use_cpu_initialization: if True, allocate the local weight on CPU.
            Initialization builds the full master weight and copies this
            rank's rows into it. Otherwise the local weight is allocated on
            the current CUDA device and initialized under the CUDA RNG
            tracker fork.
        perform_initialization: if False, the weight is left as allocated
            by ``torch.empty`` (uninitialized).
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, *,
                 init_method=init.xavier_normal_,
                 params_dtype: torch.dtype=None,
                 use_cpu_initialization: bool=False,
                 perform_initialization: bool=True):
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        # Set the defaults for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = \
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, get_tensor_model_parallel_rank(),
                self.tensor_model_parallel_size)
        self.num_embeddings_per_partition = self.vocab_end_index - \
            self.vocab_start_index

        # Allocate weights and initialize.
        if use_cpu_initialization:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                dtype=params_dtype))
            if perform_initialization:
                _initialize_affine_weight_cpu(
                    self.weight, self.num_embeddings, self.embedding_dim,
                    self.num_embeddings_per_partition, 0, init_method,
                    params_dtype=params_dtype)
        else:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                device=torch.cuda.current_device(), dtype=params_dtype))
            if perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=0, stride=1)

    def forward(self, input_):
        """Look up embeddings for token ids ``input_``.

        With tensor parallelism, ids outside this rank's
        ``[vocab_start_index, vocab_end_index)`` range are looked up as
        local row 0. Their output rows are then zeroed. The per-rank outputs
        are combined with ``reduce_from_tensor_model_parallel_region``
        (presumably an all-reduce sum -- confirm in .mappings).
        """
        if self.tensor_model_parallel_size > 1:
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
            # Mask the input. The subtraction already allocates a new tensor,
            # so the in-place masking below never touches the caller's input.
            masked_input = input_ - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Mask the output embedding.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_tensor_model_parallel_region(output_parallel)
        return output


class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
    """See linear_with_grad_accumulation_and_async_allreduce"""

    @staticmethod
    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
                async_grad_allreduce, sequence_parallel):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
        ctx.async_grad_allreduce = async_grad_allreduce
        ctx.sequence_parallel = sequence_parallel

        if sequence_parallel:
            # All-gather the sequence-sharded input along dim 0 into a
            # shared global buffer.
            world_size = get_tensor_model_parallel_world_size()
            dim_size = list(input.size())
            dim_size[0] = dim_size[0] * world_size

            all_gather_buffer = \
                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
            torch.distributed._all_gather_base(
                all_gather_buffer,
                input,
                group=get_tensor_model_parallel_group())
            total_input = all_gather_buffer
        else:
            total_input = input

        output = torch.matmul(total_input, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias

        if ctx.sequence_parallel:
            world_size = get_tensor_model_parallel_world_size()
            dim_size = list(input.size())
            dim_size[0] = dim_size[0] * world_size

            all_gather_buffer = \
                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
            handle = torch.distributed._all_gather_base(
                all_gather_buffer,
                input,
                group=get_tensor_model_parallel_group(), async_op=True)

            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # gather is scheduled before the input gradient computation
            total_input = all_gather_buffer
        else:
            total_input = input
        grad_input = grad_output.matmul(weight)

        if ctx.sequence_parallel:
            handle.wait()

        # .view() below requires contiguous memory, and the incoming gradient
        # is not guaranteed to be contiguous. .contiguous() is a no-op when it
        # already is.
        grad_output = grad_output.contiguous()
        # Convert the tensor shapes to 2D for execution compatibility
        # (assumes 3D [sequence, batch, hidden] tensors).
        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
                                       grad_output.shape[2])
        total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
                                       total_input.shape[2])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
            handle = torch.distributed.all_reduce(
                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # all-reduce is scheduled before the weight gradient computation

        if ctx.sequence_parallel:
            assert not ctx.async_grad_allreduce
            dim_size = list(input.size())
            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
                                         device=torch.cuda.current_device(),
                                         requires_grad=False)
            # reduce_scatter
            handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
                                                            group=get_tensor_model_parallel_group(),
                                                            async_op=True)
            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # reduce scatter is scheduled before the weight gradient computation

        if ctx.gradient_accumulation_fusion:
            # Accumulate directly into weight.main_grad (assumed to be
            # allocated by the caller) and return no grad_weight.
            if weight.main_grad.dtype == torch.float32:
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
            elif weight.main_grad.dtype == torch.float16:
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, weight.main_grad)
            else:
                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
            grad_weight = None
        else:
            grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.sequence_parallel:
            handle.wait()
            return sub_grad_input, grad_weight, grad_bias, None, None, None

        if ctx.async_grad_allreduce:
            handle.wait()

        return grad_input, grad_weight, grad_bias, None, None, None


def linear_with_grad_accumulation_and_async_allreduce(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    gradient_accumulation_fusion: bool,
    async_grad_allreduce: bool,
    sequence_parallel_enabled: bool,
) -> torch.Tensor:
    """Linear layer execution with asynchronous communication and
    gradient accumulation fusion in backprop.

    This has the option to accumulate the result of backprop
    calculation into an existing gradient buffer. That avoids an
    additional addition kernel after the gradient calculation.

    Additionally, the tensor parallel all reduce of the input
    gradients can be done asynchronously with the calculation of
    the weight gradients.

    In the case of sequence parallelism, the reduce scatter of the
    input gradients is done asynchronously with the calculation of
    the weight gradients.

    Use of this module requires that the environment variable
    CUDA_DEVICE_MAX_CONNECTIONS=1. A few collective operations, noted
    in the code, should be scheduled before compute kernels so that
    communication overlaps with computation. That ordering is needed
    for the speedup but not for correctness, so the scheduler does not
    impose it. Setting CUDA_DEVICE_MAX_CONNECTIONS=1 forces kernels to
    be scheduled in the order they are called.

    Arguments:

    input (torch.Tensor required): input like torch.nn.functional.linear

    weight (torch.Tensor required): weight like torch.nn.functional.linear

    bias (torch.Tensor optional): bias like torch.nn.functional.linear

    gradient_accumulation_fusion (bool required): Perform the gradient
        accumulation fusion. This requires the custom CUDA extension
        module fused_weight_gradient_mlp_cuda. To use
        gradient_accumulation_fusion you must install APEX with
        --cpp_ext and --cuda_ext. For example: "pip install
        --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\"
        " Note that the extension requires CUDA>=11. Otherwise, you
        must turn off gradient accumulation fusion.

    async_grad_allreduce (bool required): Do the allreduce of input
        gradients asynchronously with the computation of weight
        gradients. If sequence_parallel_enabled is True, this must be
        False, as no all reduce is performed.

    sequence_parallel_enabled (bool required): Indicates that sequence
        parallelism is used. The input is all-gathered in the forward
        pass, and the input gradients are reduce-scattered in the
        backward pass.
    """
    args = [
        input,
        weight,
        bias,
        gradient_accumulation_fusion,
        async_grad_allreduce,
        sequence_parallel_enabled,
    ]

    # Warn at most once per process about the missing scheduling hint.
    if not linear_with_grad_accumulation_and_async_allreduce.warned:
        if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
            if sequence_parallel_enabled:
                warnings.warn(
                    "When using sequence parallelism it is recommended to set the "
                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
                    "maximum speedup")
                linear_with_grad_accumulation_and_async_allreduce.warned = True

            if async_grad_allreduce:
                warnings.warn(
                    "When using async grad allreduce it is recommended to set the "
                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
                    "maximum speedup")
                linear_with_grad_accumulation_and_async_allreduce.warned = True

    # The custom autograd function manages its own dtypes, so autocast is
    # disabled around it.
    with torch.cuda.amp.autocast(enabled=False):
        return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)

linear_with_grad_accumulation_and_async_allreduce.warned = False


class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.

    Keyword Arguments:
        bias: If true, add bias
        gather_output: If true, call all-gather on output and make Y
                       available to all GPUs, otherwise, every GPU will
                       have its output which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always
                     set to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should
                                     be set to False. It returns the master
                                     weights used for initialization.
        skip_bias_add: This was added to enable performance optimizations
                       where bias can be fused with other elementwise
                       operations. We skip adding bias but instead return it.
        async_tensor_model_parallel_allreduce: If true and the tensor model
            parallel world size is > 1, the input-gradient all-reduce is
            issued asynchronously inside the linear autograd function
            rather than via copy_to_tensor_model_parallel_region.
        params_dtype: dtype of weight and bias; defaults to
            ``torch.get_default_dtype()``.
        use_cpu_initialization: allocate parameters on CPU and initialize the
            weight from a full master weight; otherwise allocate on the
            current CUDA device.
        perform_initialization: if False, the weight is left uninitialized
            (``torch.empty``); the bias is still zeroed.
        gradient_accumulation_fusion: accumulate the weight gradient into
            ``weight.main_grad`` using fused_weight_gradient_mlp_cuda.
        sequence_parallel_enabled: the input is sharded along its first
            dimension and all-gathered in forward. It is disabled (with a
            warning) when the tensor model parallel size is 1.
    """

    def __init__(self, input_size, output_size, *,
                 bias=True, gather_output=True,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False,
                 async_tensor_model_parallel_allreduce=True,
                 params_dtype=None,
                 use_cpu_initialization=False,
                 perform_initialization=True,
                 gradient_accumulation_fusion=False,
                 sequence_parallel_enabled: bool = False,
                 ):
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        world_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        if use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                                self.input_size,
                                                dtype=params_dtype))
            if perform_initialization:
                # Pass params_dtype so the master weight is cast to the same
                # dtype as the local shard (matches VocabParallelEmbedding
                # and RowParallelLinear).
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight, self.output_size, self.input_size,
                    self.output_size_per_partition, 0, init_method,
                    stride=stride, return_master_weight=keep_master_weight_for_test,
                    params_dtype=params_dtype)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size_per_partition, self.input_size,
                device=torch.cuda.current_device(), dtype=params_dtype))
            if perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=0, stride=stride)

        if bias:
            if use_cpu_initialization:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition, dtype=params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=params_dtype))
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

        # The async all-reduce only makes sense with more than one rank.
        self.async_tensor_model_parallel_allreduce = (
                async_tensor_model_parallel_allreduce and
                world_size > 1)
        if sequence_parallel_enabled:
            if world_size <= 1:
                warnings.warn(
                    f"`sequence_parallel_enabled` is set to `True`, but tensor model parallel size is {world_size}. "
                    f"Disabling sequence parallel."
                )
                sequence_parallel_enabled = False
        self.sequence_parallel_enabled = sequence_parallel_enabled

        if gradient_accumulation_fusion:
            if not _grad_accum_fusion_available:
                raise RuntimeError(
                    "ColumnParallelLinear was called with gradient_accumulation_fusion set "
                    "to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
                    "module is not found. To use gradient_accumulation_fusion you must "
                    "install APEX with --cpp_ext and --cuda_ext. For example: "
                    "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" "
                    "Note that the extension requires CUDA>=11. Otherwise, you must turn off "
                    "gradient accumulation fusion."
                )
        self.gradient_accumulation_fusion = gradient_accumulation_fusion

        if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled:
            raise RuntimeError(
                "`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` "
                "cannot be enabled at the same time."
            )

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: 3D tensor whose order of dimension is [sequence, batch, hidden]

        Returns:
            - output
            - bias
        """
        bias = self.bias if not self.skip_bias_add else None

        if self.async_tensor_model_parallel_allreduce or \
                self.sequence_parallel_enabled:
            # Backward communication is handled inside the linear function.
            input_parallel = input_
        else:
            input_parallel = copy_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = linear_with_grad_accumulation_and_async_allreduce(
            input=input_parallel,
            weight=self.weight,
            bias=bias,
            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
            async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
            sequence_parallel_enabled=self.sequence_parallel_enabled,
        )
        if self.gather_output:
            # All-gather across the partitions.
            assert not self.sequence_parallel_enabled
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.

    Keyword Arguments:
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always
                     set to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should
                                     be set to False. It returns the master
                                     weights used for initialization.
        skip_bias_add: This was added to enable performance optimization
                       where bias can be fused with other elementwise
                       operations. We skip adding bias but instead return it.
        params_dtype: dtype of weight and bias; defaults to
            ``torch.get_default_dtype()``.
        use_cpu_initialization: allocate parameters on CPU and initialize the
            weight from a full master weight; otherwise allocate on the
            current CUDA device.
        perform_initialization: if False, the weight is left uninitialized
            (``torch.empty``); the bias is still zeroed.
        gradient_accumulation_fusion: accumulate the weight gradient into
            ``weight.main_grad`` using fused_weight_gradient_mlp_cuda.
        sequence_parallel_enabled: reduce-scatter the output along the first
            dimension instead of all-reducing it; requires input_is_parallel.
    """

    def __init__(self, input_size, output_size, *,
                 bias=True, input_is_parallel=False,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False,
                 params_dtype=None,
                 use_cpu_initialization=False,
                 perform_initialization=True,
                 gradient_accumulation_fusion=False,
                 sequence_parallel_enabled: bool = False,
                 ):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        # Divide the weight matrix along the last dimension.
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add
        self.gradient_accumulation_fusion = gradient_accumulation_fusion
        self.sequence_parallel_enabled = sequence_parallel_enabled
        if self.sequence_parallel_enabled and not self.input_is_parallel:
            raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`")

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        if use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size,
                                                self.input_size_per_partition,
                                                dtype=params_dtype))
            if perform_initialization:
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight, self.output_size, self.input_size,
                    self.input_size_per_partition, 1, init_method,
                    stride=stride, return_master_weight=keep_master_weight_for_test,
                    params_dtype=params_dtype)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size, self.input_size_per_partition,
                device=torch.cuda.current_device(), dtype=params_dtype))
            if perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=1, stride=stride)
        if bias:
            if use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size,
                                                  dtype=params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size, device=torch.cuda.current_device(),
                    dtype=params_dtype))
            setattr(self.bias, 'sequence_parallel', sequence_parallel_enabled)

            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: 3D tensor whose order of dimension is [sequence, batch, hidden]

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            assert not self.sequence_parallel_enabled
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply. Bias is added only after the cross-rank reduction
        # so it is applied once rather than once per rank.
        output_parallel = linear_with_grad_accumulation_and_async_allreduce(
            input=input_parallel,
            weight=self.weight,
            bias=None,
            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
            async_grad_allreduce=False,
            sequence_parallel_enabled=False,
        )

        # All-reduce across all the partitions.
        if self.sequence_parallel_enabled:
            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
        else:
            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias