fix: resolve mypy type errors in isaac_patch_hf_runner

Signed-off-by: Oscar Gonzalez <ogonzal6@alumni.jh.edu>
This commit is contained in:
Oscar Gonzalez 2025-12-22 18:24:58 -05:00
parent af3529f651
commit 4789b99110

View File

@ -145,9 +145,11 @@ def isaac_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
)
# Initialize and collect hidden states
all_hidden_states = ()
hidden_states = inputs_embeds
all_hidden_states += (hidden_states,)
hidden_states_list: list[torch.Tensor] = []
if output_hidden_states:
hidden_states_list.append(hidden_states)
for decoder_layer in self.layers:
layer_outputs = decoder_layer(
@ -164,11 +166,18 @@ def isaac_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
hidden_states = (
layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs
)
all_hidden_states += (hidden_states,)
if output_hidden_states:
hidden_states_list.append(hidden_states)
# Final layer norm
hidden_states = self.norm(hidden_states)
all_hidden_states += (hidden_states,)
if output_hidden_states:
hidden_states_list.append(hidden_states)
# Convert to tuple or None
all_hidden_states = tuple(hidden_states_list) if output_hidden_states else None
# Include hiden_states for compatibility with hidden_states_to_seq_logprobs()
return BaseModelOutputWithPast(