mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 13:56:32 +08:00
[Core] Support tensor parallelism for GGUF quantization (#7520)
This commit is contained in:
parent
47b65a5508
commit
7601cb044d
@ -7,6 +7,7 @@ import os
|
||||
|
||||
import pytest
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
|
||||
@ -20,7 +21,7 @@ MAX_MODEL_LEN = 1024
|
||||
MODELS = [
|
||||
("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
|
||||
filename="tinyllama-1.1b-chat-v1.0.Q4_0.gguf")),
|
||||
filename="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")),
|
||||
("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF",
|
||||
filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")),
|
||||
@ -39,22 +40,36 @@ MODELS = [
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("tp_size", [1, 2])
|
||||
def test_models(
|
||||
num_gpus_available,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tp_size: int,
|
||||
) -> None:
|
||||
if num_gpus_available < tp_size:
|
||||
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
|
||||
|
||||
original_model, gguf_model = model
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(original_model)
|
||||
messages = [[{
|
||||
'role': 'user',
|
||||
'content': prompt
|
||||
}] for prompt in example_prompts]
|
||||
example_prompts = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
|
||||
# Run unquantized model.
|
||||
with vllm_runner(model_name=original_model,
|
||||
dtype=dtype,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=1) as original_model:
|
||||
tensor_parallel_size=tp_size) as original_model:
|
||||
|
||||
original_outputs = original_model.generate_greedy_logprobs(
|
||||
example_prompts[:-1], max_tokens, num_logprobs)
|
||||
@ -63,8 +78,7 @@ def test_models(
|
||||
with vllm_runner(model_name=gguf_model,
|
||||
dtype=dtype,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=1) as gguf_model:
|
||||
tensor_parallel_size=tp_size) as gguf_model:
|
||||
gguf_outputs = gguf_model.generate_greedy_logprobs(
|
||||
example_prompts[:-1], max_tokens, num_logprobs)
|
||||
|
||||
|
||||
@ -507,11 +507,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
loaded_shard_id
|
||||
|
||||
if is_gguf_weight:
|
||||
shard_size = loaded_weight.shape[output_dim]
|
||||
shard_offset = loaded_weight.shape[output_dim] * \
|
||||
loaded_shard_id
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
shard_shape = list(loaded_weight.shape)
|
||||
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
|
||||
param.shard_id.append(loaded_shard_id)
|
||||
param.shard_size[loaded_shard_id] = loaded_weight.shape
|
||||
param.shard_size[loaded_shard_id] = shard_shape
|
||||
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
input_size = loaded_weight.shape[input_dim]
|
||||
param_data = param_data.narrow(input_dim, 0, input_size)
|
||||
|
||||
param_data = param_data.narrow(output_dim, shard_offset,
|
||||
shard_size)
|
||||
@ -863,8 +868,13 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
param, orig_qkv_offsets, loaded_shard_id)
|
||||
|
||||
if is_gguf_weight:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
shard_shape = list(loaded_weight.shape)
|
||||
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
|
||||
param.shard_id.append(loaded_shard_id)
|
||||
param.shard_size[loaded_shard_id] = loaded_weight.shape
|
||||
param.shard_size[loaded_shard_id] = shard_shape
|
||||
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
input_size = loaded_weight.shape[input_dim]
|
||||
param_data = param_data.narrow(input_dim, 0, input_size)
|
||||
@ -976,6 +986,7 @@ class RowParallelLinear(LinearBase):
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
|
||||
# Special case for GGUF
|
||||
@ -986,7 +997,10 @@ class RowParallelLinear(LinearBase):
|
||||
|
||||
# Materialize GGUF UninitializedParameter
|
||||
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
|
||||
weight_shape = list(loaded_weight.shape)
|
||||
if input_dim:
|
||||
weight_shape[input_dim] = weight_shape[input_dim] // tp_size
|
||||
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
|
||||
|
||||
param_data = param.data
|
||||
if input_dim is not None:
|
||||
|
||||
@ -5,7 +5,6 @@ import torch
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
@ -39,9 +38,6 @@ class GGUFConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
raise ValueError(
|
||||
"GGUF quantization hasn't supported tensor parallelism yet.")
|
||||
return cls()
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user