From bab4dea597e3ecefd665ca9b530225407adfbcb8 Mon Sep 17 00:00:00 2001 From: Oscar Gonzalez Date: Tue, 23 Dec 2025 15:43:21 -0500 Subject: [PATCH] Add perceptron dependency for Isaac tests and refactor tests. Signed-off-by: Oscar Gonzalez --- requirements/test.in | 2 + requirements/test.txt | 21 ++- .../multimodal/generation/test_common.py | 25 +++ .../generation/vlm_utils/model_utils.py | 177 ++++++++++++++++++ 4 files changed, 223 insertions(+), 2 deletions(-) diff --git a/requirements/test.in b/requirements/test.in index 55452ce83f232..68e607ff7e308 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -57,3 +57,5 @@ pydantic>=2.12 # 2.11 leads to error on python 3.13 decord==0.6.0 terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3 # required for PrithviMAE test gpt-oss >= 0.0.7; python_version > '3.11' + +perceptron # required for isaac test diff --git a/requirements/test.txt b/requirements/test.txt index ea2093e4347fe..843a8212b819f 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -135,6 +135,7 @@ cloudpickle==3.1.1 # via mlflow-skinny colorama==0.4.6 # via + # perceptron # sacrebleu # schemathesis # tqdm-multiprocess @@ -302,6 +303,8 @@ h11==0.14.0 # via # httpcore # uvicorn +h2==4.3.0 + # via httpx h5py==3.13.0 # via terratorch harfile==0.3.0 @@ -310,6 +313,8 @@ hf-xet==1.1.7 # via huggingface-hub hiredis==3.0.0 # via tensorizer +hpack==4.1.0 + # via h2 html2text==2025.4.15 # via gpt-oss httpcore==1.0.6 @@ -317,6 +322,7 @@ httpcore==1.0.6 httpx==0.27.2 # via # -r requirements/test.in + # perceptron # schemathesis huggingface-hub==0.34.3 # via @@ -338,6 +344,8 @@ hydra-core==1.3.2 # via # lightly # lightning +hyperframe==6.1.0 + # via h2 hypothesis==6.131.0 # via # hypothesis-graphql @@ -549,6 +557,7 @@ numpy==1.26.4 # pandas # patsy # peft + # perceptron # pycocotools # pyogrio # rasterio @@ -702,6 +711,8 @@ peft==0.16.0 # via # -r requirements/test.in # lm-eval +perceptron==0.1.4 + # via -r requirements/test.in 
pillow==10.4.0 # via # genai-perf @@ -709,6 +720,7 @@ pillow==10.4.0 # lightly-utils # matplotlib # mistral-common + # perceptron # scikit-image # segmentation-models-pytorch # sentence-transformers @@ -952,6 +964,7 @@ rich==13.9.4 # genai-perf # lightning # mteb + # perceptron # typer rioxarray==0.19.0 # via terratorch @@ -1024,7 +1037,9 @@ shapely==2.1.1 # geopandas # torchgeo shellingham==1.5.4 - # via typer + # via + # perceptron + # typer six==1.16.0 # via # junit-xml @@ -1218,7 +1233,9 @@ typepy==1.3.2 # pytablewriter # tabledata typer==0.15.2 - # via fastsafetensors + # via + # fastsafetensors + # perceptron types-python-dateutil==2.9.0.20241206 # via arrow typeshed-client==2.8.2 diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index c5a0b6748f797..17f45e79a7387 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -510,6 +510,31 @@ VLM_TEST_SETTINGS = { use_tokenizer_eos=True, auto_cls=AutoModelForImageTextToText, ), + "isaac": VLMTestInfo( + models=["PerceptronAI/Isaac-0.1"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: ( + f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n" + ), + img_idx_to_prompt=lambda idx: "", + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "Please describe the image shortly.", + "cherry_blossom": "Please infer the season with reason.", + } + ), + multi_image_prompt=( + "Picture 1: \n" + "Picture 2: \n" + "Describe these two images with one paragraph respectively." 
def isaac_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    """Patch HF runner for Isaac.

    Three patches are applied in place on ``hf_model``:
      1) The processor is wrapped so every tensor in its ``BatchFeature``
         output (``input_ids``, ``TensorStream``, ...) is moved to the
         model's device.
      2) ``generate`` is wrapped to force ``pad_token_id``/``eos_token_id``
         to the tokenizer's EOS id so HF generation terminates cleanly.
      3) ``IsaacModel.forward`` is replaced so the returned
         ``BaseModelOutputWithPast`` carries ``hidden_states``, for
         compatibility with hidden_states_to_seq_logprobs().

    Returns the same ``hf_model`` instance, mutated.
    """

    from perceptron.tensorstream import TextType
    from perceptron.tensorstream.ops import compute_mrope_pos_tensor, modality_mask
    from transformers.modeling_outputs import BaseModelOutputWithPast

    def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor:
        """Create 3D positional indices (B, L, 3) for plain-text input.

        The same sequential index is replicated across the trailing axis of
        size 3, one entry per MRoPE section.
        """
        batch_size, seq_length = input_ids.shape
        position_ids = torch.arange(seq_length, device=input_ids.device)
        position_ids = position_ids.view(1, -1).expand(batch_size, -1)
        position_ids = position_ids.unsqueeze(2).expand(-1, -1, 3)  # Add 3D for MRoPE
        return position_ids

    model_device = next(hf_model.model.parameters()).device

    # ----------------------------
    # 1) Patch processor: move BatchFeature input_ids and TensorStream to model device
    # ----------------------------
    original_processor = hf_model.processor

    def patched_processor(*args, **kwargs):
        result = original_processor(*args, **kwargs)
        # Every value in result.data is expected to support ``.to(device)``
        # (tensors / TensorStream) — TODO confirm against the Isaac processor.
        for k, v in result.data.items():
            result[k] = v.to(model_device)
        return result

    hf_model.processor = patched_processor

    tokenizer = AutoTokenizer.from_pretrained(
        hf_model.model_name, trust_remote_code=True
    )

    original_generate = hf_model.model.generate

    def patched_generate(*args, **kwargs):
        # Force pad/eos ids: overrides any caller-supplied values.
        kwargs["pad_token_id"] = tokenizer.eos_token_id
        kwargs["eos_token_id"] = tokenizer.eos_token_id
        return original_generate(*args, **kwargs)

    hf_model.model.generate = patched_generate

    # ----------------------------
    # 2) Patch IsaacModel.forward: add hidden_states to the output
    # ----------------------------
    isaac_model = hf_model.model.model

    def patched_forward(
        self,
        input_ids=None,
        tensor_stream=None,
        attention_mask=None,
        position_ids=None,
        modality_tensor=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
        **kwargs,
    ):
        """
        Forward pass with MRoPE position embeddings.
        Computes position embeddings once and passes them through all layers.

        NOTE(review): when ``output_hidden_states`` is set, the returned
        tuple contains the last decoder layer's state both pre-norm and
        post-norm (n_layers + 2 entries), unlike stock HF decoders
        (n_layers + 1). Harmless for consumers that only read ``[-1]``.
        """
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        # NOTE(review): return_dict is resolved for signature parity but
        # unused — this forward always returns a BaseModelOutputWithPast.
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Resolve the input embeddings from exactly one of:
        # tensor_stream, input_ids, or inputs_embeds.
        if tensor_stream is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both tensor_stream and inputs_embeds")
        elif tensor_stream is not None:
            # Embed TensorStream directly
            inputs_embeds = self.embed_stream(tensor_stream)
            # Create modality tensor if not provided
            if modality_tensor is None:
                modality_tensor = modality_mask(tensor_stream)
        elif input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            inputs_embeds = self.embed_tokens(input_ids)
            # Create text modality tensor if not provided
            if modality_tensor is None:
                batch_size, seq_length = input_ids.shape
                modality_tensor = torch.full(
                    (batch_size, seq_length),
                    TextType.text.value,
                    device=input_ids.device,
                    dtype=torch.long,
                )
        elif inputs_embeds is None:
            raise ValueError(
                "You have to specify either tensor_stream, input_ids or inputs_embeds"
            )

        # Create default position_ids if not provided
        if position_ids is None:
            if tensor_stream is not None:
                position_ids = compute_mrope_pos_tensor(tensor_stream)  # (B,L,3)
            else:
                position_ids = compute_position_ids_input_ids(input_ids)

        # Compute MRoPE position embeddings once; reused by every layer.
        cos, sin = self.rotary_emb(position_ids, modality_tensor)
        cos = cos.to(inputs_embeds.dtype)
        sin = sin.to(inputs_embeds.dtype)

        # Prepare attention mask
        if attention_mask is not None:
            attention_mask = self._update_causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values, False
            )

        # Initialize and collect hidden states
        hidden_states = inputs_embeds
        hidden_states_list: list[torch.Tensor] = []

        if output_hidden_states:
            hidden_states_list.append(hidden_states)

        for decoder_layer in self.layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=(cos, sin),
                **kwargs,
            )

            # Layers may return either a bare tensor or a tuple whose first
            # element is the hidden state.
            hidden_states = (
                layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs
            )

            if output_hidden_states:
                hidden_states_list.append(hidden_states)

        # Final layer norm
        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            hidden_states_list.append(hidden_states)

        # Convert to tuple or None
        all_hidden_states = tuple(hidden_states_list) if output_hidden_states else None

        # Include hidden_states for compatibility with hidden_states_to_seq_logprobs()
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
        )

    isaac_model.forward = types.MethodType(patched_forward, isaac_model)

    return hf_model