[Model] Enable quantization support for transformers backend (#12960)
This commit is contained in: parent efbe854448, commit 67ef8f666a
@@ -42,7 +42,7 @@ Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project

### Transformers fallback

After the merge of <gh-pr:11330>, `vllm` can fall back to models that are available in `transformers`. This does not work for all models for now, but most decoder language models are supported, and vision language model support is planned!

`vllm` can fall back to models that are available in `transformers`. This does not work for all models for now, but most decoder language models are supported, and vision language model support is planned!

To check if the backend is `transformers`, you can simply do this:
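For reference, a minimal sketch of such a check, assuming the `model_impl` engine argument and the `LLM.apply_model` helper; the model name is only an example:

```python
from vllm import LLM

# Force the Transformers fallback instead of a native vLLM implementation.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", model_impl="transformers")

# Print the class of the loaded model; `TransformersModel` indicates the
# Transformers fallback is in use.
llm.apply_model(lambda model: print(type(model)))
```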
@@ -56,9 +56,13 @@ If it is `TransformersModel` then it means it's based on `transformers`!

#### Supported features

##### LORA and quantization

##### Quantization

Both are not supported yet! Make sure to open an issue and we'll work on this together with the `transformers` team!

The Transformers fallback supports most of the quantization methods available in vLLM (except GGUF). See the [Quantization page](#quantization-index) for more information about the quantization methods supported in vLLM.
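As an illustration, a hedged sketch of serving a bitsandbytes in-flight-quantized model through the fallback; the kwargs mirror the `test_quantization` case added below, and the model name is only an example:

```python
from vllm import LLM

# In-flight bitsandbytes quantization, served via the Transformers fallback.
llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    model_impl="transformers",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
)

outputs = llm.generate("Hello, my name is")
print(outputs[0].outputs[0].text)
```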
##### LoRA

LoRA isn't supported on the Transformers fallback yet! Make sure to open an issue and we'll work on this together with the `transformers` team!

Usually, `transformers` models load adapter weights via the `load_adapters` API, which depends on PEFT. We need to do a bit of work to either use this API (for now this would result in some weights not being marked as loaded) or replace the modules accordingly.
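For context, a rough sketch of that adapter-loading path in plain `transformers` (not vLLM code; the adapter repo id is a placeholder and `peft` must be installed):

```python
from transformers import AutoModelForCausalLM

# Load the base model, then attach a LoRA adapter through transformers'
# PEFT integration.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
model.load_adapter("my-org/my-lora-adapter")  # placeholder adapter repo id
```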
@@ -45,10 +45,14 @@ def check_implementation(
    ("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
    ("openai-community/gpt2", "transformers"),
    ("ArthurZ/Ilama-3.2-1B", "auto"),  # CUSTOM CODE
    ("meta-llama/Llama-3.2-1B-Instruct", "auto"),
])  # trust_remote_code=True by default
def test_models(hf_runner, vllm_runner, example_prompts, model,
                model_impl) -> None:
def test_models(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    example_prompts: list[str],
    model: str,
    model_impl: str,
) -> None:

    maybe_raises = nullcontext()
    if model == "openai-community/gpt2" and model_impl == "transformers":
@@ -67,10 +71,50 @@ def test_models(hf_runner, vllm_runner, example_prompts, model,

@multi_gpu_test(num_gpus=2)
def test_distributed(
    hf_runner,
    vllm_runner,
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    example_prompts,
):
    kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2}
    check_implementation(hf_runner, vllm_runner, example_prompts,
                         "meta-llama/Llama-3.2-1B-Instruct", **kwargs)


@pytest.mark.parametrize("model, quantization_kwargs", [
    (
        "meta-llama/Llama-3.2-1B-Instruct",
        {
            "quantization": "bitsandbytes",
            "load_format": "bitsandbytes",
        },
    ),
])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_quantization(
    vllm_runner: Type[VllmRunner],
    example_prompts: list[str],
    model: str,
    quantization_kwargs: dict[str, str],
    max_tokens: int,
    num_logprobs: int,
) -> None:
    with vllm_runner(
            model, model_impl="auto", enforce_eager=True,
            **quantization_kwargs) as vllm_model:  # type: ignore[arg-type]
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs)

    with vllm_runner(
            model,
            model_impl="transformers",
            enforce_eager=True,
            **quantization_kwargs) as vllm_model:  # type: ignore[arg-type]
        transformers_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs)

    check_logprobs_close(
        outputs_0_lst=transformers_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="transformers",
        name_1="vllm",
    )
@@ -28,6 +28,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -37,6 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsQuant
from .utils import maybe_prefix

logger = init_logger(__name__)
@@ -50,10 +52,10 @@ def vllm_flash_attention_forward(
        value: torch.Tensor,
        attention_mask: torch.Tensor,
        # Transformers kwargs
        scaling: float = None,
        scaling: Optional[float] = None,
        # vLLM kwargs
        attn_metadata: AttentionMetadata = None,
        attention_instances: list[Attention] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        attention_instances: Optional[list[Attention]] = None,
        **kwargs):
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
@@ -99,13 +101,7 @@ def replace_linear_class(
    vllm_linear_cls = {
        "colwise": ColumnParallelLinear,
        "rowwise": RowParallelLinear,
    }.get(style)

    if vllm_linear_cls is None:
        logger.warning(
            "Unsupported parallel style value: %s. "
            "This layer will not be tensor parallelized.", style)
        return linear
    }.get(style, ReplicatedLinear)

    class HFCompatibleLinear(vllm_linear_cls):
        """
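A toy sketch of the dispatch change in the hunk above, with hypothetical class names (not vLLM code): unknown parallel styles now fall back to a replicated linear class instead of returning the original layer unchanged.

```python
import torch.nn as nn

# Stand-ins for vLLM's parallel linear layer classes.
class Replicated(nn.Linear): ...
class Colwise(nn.Linear): ...
class Rowwise(nn.Linear): ...

def pick_linear_cls(style: str) -> type:
    # dict.get with a default replaces the old warn-and-return-early path.
    return {"colwise": Colwise, "rowwise": Rowwise}.get(style, Replicated)

assert pick_linear_cls("colwise") is Colwise
assert pick_linear_cls("unknown") is Replicated
```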
@@ -119,10 +115,11 @@ def replace_linear_class(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
    )


class TransformersModel(nn.Module):
class TransformersModel(nn.Module, SupportsQuant):
    embedding_padding_modules = ["lm_head"]
    embedding_modules = ["embed_tokens"
                         ]  # TODO transformers will have a util to get it
@@ -133,10 +130,8 @@ class TransformersModel(nn.Module):

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.vocab_size = config.vocab_size
        self.unpadded_vocab_size = config.vocab_size
@@ -162,7 +157,7 @@ class TransformersModel(nn.Module):
                scale=config.head_dim**-0.5,
                num_kv_heads=divide(config.num_key_value_heads, tp_size),
                cache_config=cache_config,
                quant_config=None,
                quant_config=self.quant_config,
                prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
        ]
@@ -172,7 +167,7 @@ class TransformersModel(nn.Module):
        # ForCausalLM modifications
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=None,
                                      quant_config=self.quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.get_input_embeddings().weight