[Core] Support tensor parallelism for GGUF quantization (#7520)

Isotr0py 2024-08-20 05:30:14 +08:00 committed by GitHub
parent 47b65a5508
commit 7601cb044d
3 changed files with 39 additions and 15 deletions
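
What the change enables, as a hedged offline-inference sketch (the local .gguf path, the prompt, and the 2-GPU setup are placeholders, not taken from this commit):

from vllm import LLM, SamplingParams

# Minimal sketch: load a GGUF checkpoint sharded across 2 GPUs, which the
# quantization config previously rejected. The file path is a placeholder;
# the tokenizer comes from the original HF repo, as in the test below.
llm = LLM(
    model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",   # placeholder local GGUF file
    tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",   # base HF repo for the tokenizer
    tensor_parallel_size=2,                            # now supported for GGUF weights
)
outputs = llm.generate(["What is tensor parallelism?"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)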


@@ -7,6 +7,7 @@ import os
 import pytest
 from huggingface_hub import hf_hub_download
+from transformers import AutoTokenizer
 from tests.quantization.utils import is_quant_method_supported
@@ -20,7 +21,7 @@ MAX_MODEL_LEN = 1024
 MODELS = [
     ("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
      hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
-                     filename="tinyllama-1.1b-chat-v1.0.Q4_0.gguf")),
+                     filename="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")),
     ("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
      hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF",
                      filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")),
@@ -39,22 +40,36 @@ MODELS = [
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("tp_size", [1, 2])
 def test_models(
+    num_gpus_available,
     vllm_runner,
     example_prompts,
     model,
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
+    tp_size: int,
 ) -> None:
+    if num_gpus_available < tp_size:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
     original_model, gguf_model = model
+    tokenizer = AutoTokenizer.from_pretrained(original_model)
+    messages = [[{
+        'role': 'user',
+        'content': prompt
+    }] for prompt in example_prompts]
+    example_prompts = tokenizer.apply_chat_template(messages,
+                                                    tokenize=False,
+                                                    add_generation_prompt=True)
     # Run unquantized model.
     with vllm_runner(model_name=original_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      enforce_eager=True,
-                     tensor_parallel_size=1) as original_model:
+                     tensor_parallel_size=tp_size) as original_model:
         original_outputs = original_model.generate_greedy_logprobs(
             example_prompts[:-1], max_tokens, num_logprobs)
@@ -63,8 +78,7 @@ def test_models(
     with vllm_runner(model_name=gguf_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      enforce_eager=True,
-                     tensor_parallel_size=1) as gguf_model:
+                     tensor_parallel_size=tp_size) as gguf_model:
         gguf_outputs = gguf_model.generate_greedy_logprobs(
             example_prompts[:-1], max_tokens, num_logprobs)
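
The prompt handling added to this test can be exercised standalone; a minimal sketch assuming the same TinyLlama tokenizer (the example prompt is made up, the calls mirror the test above):

from transformers import AutoTokenizer

# Each raw prompt is wrapped as a single-turn user message and rendered with
# the model's chat template, so the FP16 and GGUF runs compare logprobs on
# identical formatted inputs.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
prompts = ["Describe the capital of France."]  # placeholder prompt, not from the test
messages = [[{"role": "user", "content": p}] for p in prompts]
rendered = tokenizer.apply_chat_template(messages,
                                         tokenize=False,
                                         add_generation_prompt=True)
print(rendered[0])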


@@ -507,11 +507,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     loaded_shard_id
             if is_gguf_weight:
-                shard_size = loaded_weight.shape[output_dim]
-                shard_offset = loaded_weight.shape[output_dim] * \
-                    loaded_shard_id
+                tp_size = get_tensor_model_parallel_world_size()
+                output_dim = getattr(param, "output_dim", None)
+                shard_shape = list(loaded_weight.shape)
+                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
                 param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = loaded_weight.shape
+                param.shard_size[loaded_shard_id] = shard_shape
                 input_dim = getattr(param, "input_dim", None)
                 input_size = loaded_weight.shape[input_dim]
                 param_data = param_data.narrow(input_dim, 0, input_size)
                 param_data = param_data.narrow(output_dim, shard_offset,
                                                shard_size)
@@ -863,8 +868,13 @@ class QKVParallelLinear(ColumnParallelLinear):
                     param, orig_qkv_offsets, loaded_shard_id)
             if is_gguf_weight:
+                tp_size = get_tensor_model_parallel_world_size()
+                output_dim = getattr(param, "output_dim", None)
+                shard_shape = list(loaded_weight.shape)
+                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
                 param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = loaded_weight.shape
+                param.shard_size[loaded_shard_id] = shard_shape
                 input_dim = getattr(param, "input_dim", None)
                 input_size = loaded_weight.shape[input_dim]
                 param_data = param_data.narrow(input_dim, 0, input_size)
@@ -976,6 +986,7 @@ class RowParallelLinear(LinearBase):
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         # Special case for GGUF
@@ -986,7 +997,10 @@ class RowParallelLinear(LinearBase):
         # Materialize GGUF UninitializedParameter
         if is_gguf_weight and isinstance(param, UninitializedParameter):
-            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
+            weight_shape = list(loaded_weight.shape)
+            if input_dim:
+                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
+            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
         param_data = param.data
         if input_dim is not None:
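
The shard bookkeeping above boils down to simple shape arithmetic; a minimal standalone sketch with made-up shapes and a fixed rank (not the vLLM code path itself):

import torch

tp_size, tp_rank = 2, 0  # assumed world size and rank for illustration
loaded_weight = torch.empty(2048, 576, dtype=torch.uint8)  # made-up packed GGUF data
output_dim, input_dim = 0, 1

# Column-parallel bookkeeping: record a shard shape whose output dimension is
# divided by tp_size, as the Merged/QKV loaders store in param.shard_size.
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] //= tp_size

# Row-parallel materialization: allocate with the input dimension divided by
# tp_size, then narrow each rank's slice out of the full tensor.
weight_shape = list(loaded_weight.shape)
weight_shape[input_dim] //= tp_size
row_shard = loaded_weight.narrow(input_dim, tp_rank * weight_shape[input_dim],
                                 weight_shape[input_dim])

print(shard_shape, weight_shape, row_shard.shape)
# [1024, 576] [2048, 288] torch.Size([2048, 288])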


@@ -5,7 +5,6 @@ import torch
 from torch.nn.parameter import Parameter, UninitializedParameter
 from vllm import _custom_ops as ops
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
@@ -39,9 +38,6 @@ class GGUFConfig(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
-        if get_tensor_model_parallel_world_size() > 1:
-            raise ValueError(
-                "GGUF quantization hasn't supported tensor parallelism yet.")
         return cls()
     def get_quant_method(self, layer: torch.nn.Module,
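
With the guard above removed, an uneven tensor-parallel split would now surface later as a shape error; a rough pre-flight sketch using the gguf-py GGUFReader (the reader-based check and the file path are assumptions, not part of this commit):

from gguf import GGUFReader  # gguf-py package

# Rough divisibility check: the per-rank shapes above only work out when the
# sharded dimensions divide evenly by the world size. This is a sketch, not
# vLLM behavior; "model.gguf" is a placeholder path.
def report_indivisible(gguf_path: str, tp_size: int) -> None:
    reader = GGUFReader(gguf_path)
    for tensor in reader.tensors:
        dims = [int(d) for d in tensor.shape]
        if any(d > 1 and d % tp_size != 0 for d in dims):
            print(f"{tensor.name}: {dims} not evenly divisible by tp_size={tp_size}")

report_indivisible("model.gguf", tp_size=2)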