From 620fc2d09ed86a4be6e9655ba94e771620a9e697 Mon Sep 17 00:00:00 2001
From: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Date: Sat, 5 Apr 2025 21:23:40 -0700
Subject: [PATCH] [Model] fix model testing for TeleChat2ForCausalLM and V0 llama4 (#16112)

Signed-off-by: Lu Fang
---
 vllm/attention/backends/flash_attn.py   | 5 +++++
 vllm/model_executor/models/telechat2.py | 8 ++++++--
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 27bd292b51f22..c0a572b4aaea3 100755
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -617,10 +617,15 @@ class FlashAttentionImpl(AttentionImpl):
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        use_irope: bool = False,
     ) -> None:
         if blocksparse_params is not None:
             raise ValueError(
                 "FlashAttention does not support block-sparse attention.")
+        if use_irope:
+            logger.warning(
+                "Using irope in V0 is not supported yet, it will fall back "
+                "to global attention for long context.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py
index a38035e37ec73..062b1c2cf5f54 100644
--- a/vllm/model_executor/models/telechat2.py
+++ b/vllm/model_executor/models/telechat2.py
@@ -19,7 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Iterable, Set, Tuple
+from typing import Iterable, Set, Tuple, Type

 import torch

@@ -27,6 +27,7 @@
 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel

+from .llama import LlamaDecoderLayer
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
                     is_pp_missing_parameter)
@@ -120,7 +121,10 @@ class TeleChat2ForCausalLM(LlamaForCausalLM):
         },
     )

-    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
+    def _init_model(self,
+                    vllm_config: VllmConfig,
+                    prefix: str = "",
+                    layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
         return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)

     def load_weights(self, weights: Iterable[Tuple[str,
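
Note (not part of the patch): a minimal, self-contained sketch of why the `_init_model` override needs the extra `layer_type` parameter. The base `LlamaForCausalLM` forwards a `layer_type` keyword when calling `_init_model`, so an override with the old two-argument signature fails with a TypeError. The class names below are simplified stand-ins for illustration, not the real vLLM classes.

```python
# Sketch of the signature-compatibility issue the patch fixes; names are
# hypothetical stand-ins, not the actual vLLM implementations.
from typing import Type


class DecoderLayer:
    """Stand-in for LlamaDecoderLayer."""


class BaseForCausalLM:
    """Stand-in for LlamaForCausalLM: it forwards `layer_type` to
    `_init_model`, so every override has to accept that keyword."""

    def __init__(self) -> None:
        self.model = self._init_model(vllm_config=None,
                                      prefix="model",
                                      layer_type=DecoderLayer)

    def _init_model(self,
                    vllm_config,
                    prefix: str = "",
                    layer_type: Type[DecoderLayer] = DecoderLayer):
        return ("llama-model", prefix, layer_type)


class TeleChat2Like(BaseForCausalLM):
    """Stand-in for TeleChat2ForCausalLM: it builds its own model and ignores
    `layer_type`, but the parameter must exist in the signature or the
    base-class call above raises TypeError."""

    def _init_model(self,
                    vllm_config,
                    prefix: str = "",
                    layer_type: Type[DecoderLayer] = DecoderLayer):
        return ("telechat2-model", prefix)


if __name__ == "__main__":
    print(TeleChat2Like().model)  # ('telechat2-model', 'model')
```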