diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index a05ae0edbd775..366dfd97d8163 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -452,8 +452,10 @@ class ColumnParallelLinear(LinearBase):
         else:
             self.register_parameter("bias", None)
 
+        self.tp_rank = get_tensor_model_parallel_rank()
+
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
+
         output_dim = getattr(param, "output_dim", None)
 
         is_sharded_weight = getattr(param, "is_sharded_weight", False)
@@ -472,15 +474,15 @@ class ColumnParallelLinear(LinearBase):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             final_shape = list(loaded_weight.shape)
             if output_dim is not None:
-                tp_size = get_tensor_model_parallel_world_size()
-                assert final_shape[output_dim] % tp_size == 0
-                final_shape[output_dim] = final_shape[output_dim] // tp_size
+                assert final_shape[output_dim] % self.tp_size == 0
+                final_shape[output_dim] = (final_shape[output_dim] //
+                                           self.tp_size)
             param.materialize(final_shape, dtype=loaded_weight.dtype)
 
         param_data = param.data
         if output_dim is not None and not is_sharded_weight:
             shard_size = param_data.shape[output_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
 
@@ -565,8 +567,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         return_bias: bool = True,
     ):
         self.output_sizes = output_sizes
-        tp_size = get_tensor_model_parallel_world_size()
-        assert all(output_size % tp_size == 0 for output_size in output_sizes)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        assert all(output_size % self.tp_size == 0
+                   for output_size in output_sizes)
         super().__init__(input_size=input_size,
                          output_size=sum(output_sizes),
                          bias=bias,
@@ -598,12 +603,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
 
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size
 
             if loaded_shard_id is not None:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -669,11 +672,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return
 
         assert loaded_shard_id < len(self.output_sizes)
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         if output_dim is not None:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
+                            self.tp_size)
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
             # Special case for quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -701,7 +703,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             if not is_sharded_weight:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                      shard_size)
@@ -991,12 +993,9 @@ class QKVParallelLinear(ColumnParallelLinear):
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size
 
             if loaded_shard_id is not None:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -1071,7 +1070,6 @@ class QKVParallelLinear(ColumnParallelLinear):
                 self.weight_loader(param, loaded_weight_shard, shard_id)
             return
 
-        tp_rank = get_tensor_model_parallel_rank()
         assert loaded_shard_id in ["q", "k", "v"]
 
         # If output dim is defined, use the default loading process.
@@ -1123,9 +1121,9 @@ class QKVParallelLinear(ColumnParallelLinear):
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             if loaded_shard_id == "q":
-                shard_id = tp_rank
+                shard_id = self.tp_rank
             else:
-                shard_id = tp_rank // self.num_kv_head_replicas
+                shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
             if not is_sharded_weight:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -1245,8 +1243,6 @@ class RowParallelLinear(LinearBase):
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
         is_sharded_weight = getattr(param, "is_sharded_weight", False)
@@ -1264,13 +1260,14 @@ class RowParallelLinear(LinearBase):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             weight_shape = list(loaded_weight.shape)
             if input_dim:
-                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
+                weight_shape[input_dim] = (weight_shape[input_dim] //
+                                           self.tp_size)
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
         if input_dim is not None and not is_sharded_weight:
             shard_size = param_data.shape[input_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
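
Note (illustration, not part of the patch): the diff caches the tensor-parallel rank and world size as instance attributes (self.tp_rank / self.tp_size) at construction time, so the various weight_loader methods reuse them instead of re-calling get_tensor_model_parallel_rank() / get_tensor_model_parallel_world_size() for every loaded tensor. Below is a minimal, self-contained sketch of that pattern; get_tp_rank/get_tp_size and ColumnParallelLinearSketch are hypothetical stand-ins so the snippet runs without an initialized tensor-parallel group, and the sharding is simplified to a single output dimension.

import torch


def get_tp_rank() -> int:
    # Stand-in for vLLM's get_tensor_model_parallel_rank(); hard-coded so the
    # sketch runs without distributed initialization.
    return 0


def get_tp_size() -> int:
    # Stand-in for vLLM's get_tensor_model_parallel_world_size().
    return 1


class ColumnParallelLinearSketch:
    """Illustrative only: cache TP rank/size once in __init__."""

    def __init__(self, input_size: int, output_size: int):
        # Cached here once, instead of being re-queried inside the loader.
        self.tp_rank = get_tp_rank()
        self.tp_size = get_tp_size()
        assert output_size % self.tp_size == 0
        # Each rank holds only its shard of the output dimension.
        self.weight = torch.empty(output_size // self.tp_size, input_size)

    def weight_loader(self, param: torch.Tensor,
                      loaded_weight: torch.Tensor) -> None:
        # Shard the full checkpoint tensor along the output dim (dim 0)
        # using the cached rank, mirroring the narrow() calls in the diff.
        shard_size = param.shape[0]
        start_idx = self.tp_rank * shard_size
        param.copy_(loaded_weight.narrow(0, start_idx, shard_size))


# Usage: copy the rank-local slice of a full (unsharded) weight.
layer = ColumnParallelLinearSketch(input_size=16, output_size=32)
full_weight = torch.randn(32, 16)
layer.weight_loader(layer.weight, full_weight)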