mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:54:56 +08:00
Enable CUDA graph support for llama 3.2 vision (#14917)
Signed-off-by: Matt Ritter <100659061+mritterfigma@users.noreply.github.com>
This commit is contained in:
parent
2f726b241e
commit
a8652f4f0f
@ -215,7 +215,6 @@ def _run_test(
|
||||
max_num_seqs=2,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True,
|
||||
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
|
||||
}) as vllm_model:
|
||||
vllm_outputs_per_image = [
|
||||
@ -425,7 +424,6 @@ def test_bnb_regression(
|
||||
dtype=dtype,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
enforce_eager=True,
|
||||
quantization="bitsandbytes",
|
||||
load_format="bitsandbytes",
|
||||
)
|
||||
@ -481,7 +479,6 @@ def test_explicit_implicit_prompt(
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
tensor_parallel_size=1,
|
||||
enforce_eager=True,
|
||||
)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
@ -513,7 +510,6 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
tensor_parallel_size=1,
|
||||
enforce_eager=True,
|
||||
limit_mm_per_prompt={"image":
|
||||
_LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
|
||||
|
||||
|
||||
@ -670,14 +670,6 @@ class ModelConfig:
|
||||
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
|
||||
self.max_model_len)
|
||||
|
||||
MODEL_NOT_SUPPORT_CUDA_GRAPH = ['mllama']
|
||||
if (self.hf_config.model_type in MODEL_NOT_SUPPORT_CUDA_GRAPH
|
||||
and not self.enforce_eager):
|
||||
logger.warning(
|
||||
"CUDA graph is not supported for %s yet, fallback to the eager "
|
||||
"mode.", self.hf_config.model_type)
|
||||
self.enforce_eager = True
|
||||
|
||||
def _verify_bnb_config(self) -> None:
|
||||
"""
|
||||
The current version of bitsandbytes (0.44.0) with 8-bit models does not
|
||||
|
||||
@ -1368,7 +1368,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
full_text_row_masked_out_mask = (
|
||||
attn_metadata.encoder_seq_lens_tensor
|
||||
!= 0).reshape(-1, 1).to(input_ids.device)
|
||||
skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
|
||||
skip_cross_attention = attn_metadata.max_encoder_seq_len == 0
|
||||
|
||||
# For image-present prefill.
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user