diff --git a/tests/conftest.py b/tests/conftest.py index 11c573befb2d2..317b36ba6cb80 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -459,14 +459,17 @@ class HfRunner: embeddings.append(embedding) return embeddings - def classify(self, prompts: list[str]) -> list[str]: + def classify(self, prompts: list[str]) -> list[list[float]]: # output is final logits all_inputs = self.get_inputs(prompts) - outputs = [] + outputs: list[list[float]] = [] problem_type = getattr(self.config, "problem_type", "") for inputs in all_inputs: output = self.model(**self.wrap_device(inputs)) + + assert isinstance(output.logits, torch.Tensor) + if problem_type == "regression": logits = output.logits[0].tolist() elif problem_type == "multi_label_classification":