mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 01:25:25 +08:00
[Bugfix] Ignore GPTQ quantization of Qwen2-VL visual module (#10169)
Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
parent
e0191a95d8
commit
f83feccd7f
@ -51,7 +51,9 @@ from vllm.model_executor.layers.activation import QuickGELU
|
|||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import (GPTQConfig,
|
||||||
|
GPTQMarlinConfig,
|
||||||
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
@ -982,7 +984,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
self.visual = Qwen2VisionTransformer(
|
self.visual = Qwen2VisionTransformer(
|
||||||
config.vision_config,
|
config.vision_config,
|
||||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||||
quant_config=quant_config,
|
quant_config=self._maybe_ignore_quant_config(quant_config),
|
||||||
prefix="visual",
|
prefix="visual",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1008,6 +1010,14 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
make_empty_intermediate_tensors_factory(
|
make_empty_intermediate_tensors_factory(
|
||||||
["hidden_states", "residual"], config.hidden_size))
|
["hidden_states", "residual"], config.hidden_size))
|
||||||
|
|
||||||
|
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
|
||||||
|
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
|
||||||
|
# seems to avoid vision encoder sections for some models.
|
||||||
|
# See: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4
|
||||||
|
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
|
||||||
|
return None
|
||||||
|
return quant_config
|
||||||
|
|
||||||
def _validate_and_reshape_mm_tensor(self,
|
def _validate_and_reshape_mm_tensor(self,
|
||||||
mm_input: Union[torch.Tensor,
|
mm_input: Union[torch.Tensor,
|
||||||
List[torch.Tensor]],
|
List[torch.Tensor]],
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user