hfrunner.classify should return list[list[float]] not list[str] (#29671)

Signed-off-by: Chukwuma Nwaugha <nwaughac@gmail.com>
This commit is contained in:
Chukwuma Nwaugha 2025-11-29 13:57:00 +00:00 committed by GitHub
parent f4341f45d3
commit ad7f714d62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -459,14 +459,17 @@ class HfRunner:
embeddings.append(embedding)
return embeddings
def classify(self, prompts: list[str]) -> list[str]:
def classify(self, prompts: list[str]) -> list[list[float]]:
# output is final logits
all_inputs = self.get_inputs(prompts)
outputs = []
outputs: list[list[float]] = []
problem_type = getattr(self.config, "problem_type", "")
for inputs in all_inputs:
output = self.model(**self.wrap_device(inputs))
assert isinstance(output.logits, torch.Tensor)
if problem_type == "regression":
logits = output.logits[0].tolist()
elif problem_type == "multi_label_classification":