mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 03:29:06 +08:00
[Model] Enable quantization support for transformers backend (#12960)
This commit is contained in:
parent
efbe854448
commit
67ef8f666a
@ -42,7 +42,7 @@ Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project
|
|||||||
|
|
||||||
### Transformers fallback
|
### Transformers fallback
|
||||||
|
|
||||||
After the merge of <gh-pr:11330>, `vllm` can fallback to models that are available in `transformers`. This does not work for all models for now, but most decoder language models are supported, and vision language model support is planned!
|
`vllm` can fallback to models that are available in `transformers`. This does not work for all models for now, but most decoder language models are supported, and vision language model support is planned!
|
||||||
|
|
||||||
To check if the backend is `transformers`, you can simply do this:
|
To check if the backend is `transformers`, you can simply do this:
|
||||||
|
|
||||||
@ -56,9 +56,13 @@ If it is `TransformersModel` then it means it's based on `transformers`!
|
|||||||
|
|
||||||
#### Supported features
|
#### Supported features
|
||||||
|
|
||||||
##### LORA and quantization
|
##### Quantization
|
||||||
|
|
||||||
Both are not supported yet! Make sure to open an issue and we'll work on this together with the `transformers` team!
|
Transformers fallback has supported most of available quantization in vLLM (except GGUF). See [Quantization page](#quantization-index) for more information about supported quantization in vllm.
|
||||||
|
|
||||||
|
##### LoRA
|
||||||
|
|
||||||
|
LoRA hasn't supported on transformers fallback yet! Make sure to open an issue and we'll work on this together with the `transformers` team!
|
||||||
|
|
||||||
Usually `transformers` model load weights via the `load_adapters` API, that depends on PEFT. We need to work a bit to either use this api (for now this would result in some weights not being marked as loaded) or replace modules accordingly.
|
Usually `transformers` model load weights via the `load_adapters` API, that depends on PEFT. We need to work a bit to either use this api (for now this would result in some weights not being marked as loaded) or replace modules accordingly.
|
||||||
|
|
||||||
|
|||||||
@ -45,10 +45,14 @@ def check_implementation(
|
|||||||
("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
|
("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
|
||||||
("openai-community/gpt2", "transformers"),
|
("openai-community/gpt2", "transformers"),
|
||||||
("ArthurZ/Ilama-3.2-1B", "auto"), # CUSTOM CODE
|
("ArthurZ/Ilama-3.2-1B", "auto"), # CUSTOM CODE
|
||||||
("meta-llama/Llama-3.2-1B-Instruct", "auto"),
|
|
||||||
]) # trust_remote_code=True by default
|
]) # trust_remote_code=True by default
|
||||||
def test_models(hf_runner, vllm_runner, example_prompts, model,
|
def test_models(
|
||||||
model_impl) -> None:
|
hf_runner: Type[HfRunner],
|
||||||
|
vllm_runner: Type[VllmRunner],
|
||||||
|
example_prompts: list[str],
|
||||||
|
model: str,
|
||||||
|
model_impl: str,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
maybe_raises = nullcontext()
|
maybe_raises = nullcontext()
|
||||||
if model == "openai-community/gpt2" and model_impl == "transformers":
|
if model == "openai-community/gpt2" and model_impl == "transformers":
|
||||||
@ -67,10 +71,50 @@ def test_models(hf_runner, vllm_runner, example_prompts, model,
|
|||||||
|
|
||||||
@multi_gpu_test(num_gpus=2)
|
@multi_gpu_test(num_gpus=2)
|
||||||
def test_distributed(
|
def test_distributed(
|
||||||
hf_runner,
|
hf_runner: Type[HfRunner],
|
||||||
vllm_runner,
|
vllm_runner: Type[VllmRunner],
|
||||||
example_prompts,
|
example_prompts,
|
||||||
):
|
):
|
||||||
kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2}
|
kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2}
|
||||||
check_implementation(hf_runner, vllm_runner, example_prompts,
|
check_implementation(hf_runner, vllm_runner, example_prompts,
|
||||||
"meta-llama/Llama-3.2-1B-Instruct", **kwargs)
|
"meta-llama/Llama-3.2-1B-Instruct", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model, quantization_kwargs", [
|
||||||
|
(
|
||||||
|
"meta-llama/Llama-3.2-1B-Instruct",
|
||||||
|
{
|
||||||
|
"quantization": "bitsandbytes",
|
||||||
|
"load_format": "bitsandbytes",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [32])
|
||||||
|
@pytest.mark.parametrize("num_logprobs", [5])
|
||||||
|
def test_quantization(
|
||||||
|
vllm_runner: Type[VllmRunner],
|
||||||
|
example_prompts: list[str],
|
||||||
|
model: str,
|
||||||
|
quantization_kwargs: dict[str, str],
|
||||||
|
max_tokens: int,
|
||||||
|
num_logprobs: int,
|
||||||
|
) -> None:
|
||||||
|
with vllm_runner(
|
||||||
|
model, model_impl="auto", enforce_eager=True,
|
||||||
|
**quantization_kwargs) as vllm_model: # type: ignore[arg-type]
|
||||||
|
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||||
|
example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs)
|
||||||
|
|
||||||
|
with vllm_runner(
|
||||||
|
model,
|
||||||
|
model_impl="transformers",
|
||||||
|
enforce_eager=True,
|
||||||
|
**quantization_kwargs) as vllm_model: # type: ignore[arg-type]
|
||||||
|
transformers_outputs = vllm_model.generate_greedy_logprobs(
|
||||||
|
example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs)
|
||||||
|
check_logprobs_close(
|
||||||
|
outputs_0_lst=transformers_outputs,
|
||||||
|
outputs_1_lst=vllm_outputs,
|
||||||
|
name_0="transformers",
|
||||||
|
name_1="vllm",
|
||||||
|
)
|
||||||
|
|||||||
@ -28,6 +28,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size
|
|||||||
from vllm.distributed.utils import divide
|
from vllm.distributed.utils import divide
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
|
ReplicatedLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||||
@ -37,6 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
from .interfaces import SupportsQuant
|
||||||
from .utils import maybe_prefix
|
from .utils import maybe_prefix
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -50,10 +52,10 @@ def vllm_flash_attention_forward(
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
# Transformers kwargs
|
# Transformers kwargs
|
||||||
scaling: float = None,
|
scaling: Optional[float] = None,
|
||||||
# vLLM kwargs
|
# vLLM kwargs
|
||||||
attn_metadata: AttentionMetadata = None,
|
attn_metadata: Optional[AttentionMetadata] = None,
|
||||||
attention_instances: list[Attention] = None,
|
attention_instances: Optional[list[Attention]] = None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
self_attn = attention_instances[module.layer_idx]
|
self_attn = attention_instances[module.layer_idx]
|
||||||
if scaling is not None:
|
if scaling is not None:
|
||||||
@ -99,13 +101,7 @@ def replace_linear_class(
|
|||||||
vllm_linear_cls = {
|
vllm_linear_cls = {
|
||||||
"colwise": ColumnParallelLinear,
|
"colwise": ColumnParallelLinear,
|
||||||
"rowwise": RowParallelLinear,
|
"rowwise": RowParallelLinear,
|
||||||
}.get(style)
|
}.get(style, ReplicatedLinear)
|
||||||
|
|
||||||
if vllm_linear_cls is None:
|
|
||||||
logger.warning(
|
|
||||||
"Unsupported parallel style value: %s. "
|
|
||||||
"This layer will not be tensor parallelized.", style)
|
|
||||||
return linear
|
|
||||||
|
|
||||||
class HFCompatibleLinear(vllm_linear_cls):
|
class HFCompatibleLinear(vllm_linear_cls):
|
||||||
"""
|
"""
|
||||||
@ -119,10 +115,11 @@ def replace_linear_class(
|
|||||||
input_size=linear.in_features,
|
input_size=linear.in_features,
|
||||||
output_size=linear.out_features,
|
output_size=linear.out_features,
|
||||||
bias=linear.bias is not None,
|
bias=linear.bias is not None,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TransformersModel(nn.Module):
|
class TransformersModel(nn.Module, SupportsQuant):
|
||||||
embedding_padding_modules = ["lm_head"]
|
embedding_padding_modules = ["lm_head"]
|
||||||
embedding_modules = ["embed_tokens"
|
embedding_modules = ["embed_tokens"
|
||||||
] # TODO transformers will have a util to get it
|
] # TODO transformers will have a util to get it
|
||||||
@ -133,10 +130,8 @@ class TransformersModel(nn.Module):
|
|||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
quant_config = vllm_config.quant_config
|
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.unpadded_vocab_size = config.vocab_size
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
|
|
||||||
@ -162,7 +157,7 @@ class TransformersModel(nn.Module):
|
|||||||
scale=config.head_dim**-0.5,
|
scale=config.head_dim**-0.5,
|
||||||
num_kv_heads=divide(config.num_key_value_heads, tp_size),
|
num_kv_heads=divide(config.num_key_value_heads, tp_size),
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=None,
|
quant_config=self.quant_config,
|
||||||
prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
|
prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -172,7 +167,7 @@ class TransformersModel(nn.Module):
|
|||||||
# ForCausalLM modifications
|
# ForCausalLM modifications
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
quant_config=None,
|
quant_config=self.quant_config,
|
||||||
prefix=maybe_prefix(prefix, "lm_head"))
|
prefix=maybe_prefix(prefix, "lm_head"))
|
||||||
if config.tie_word_embeddings:
|
if config.tie_word_embeddings:
|
||||||
self.lm_head.weight = self.model.get_input_embeddings().weight
|
self.lm_head.weight = self.model.get_input_embeddings().weight
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user