diff --git a/torchbenchmark/models/torch_multimodal_clip/__init__.py b/torchbenchmark/models/torch_multimodal_clip/__init__.py index 6802ec932..fda18015f 100644 --- a/torchbenchmark/models/torch_multimodal_clip/__init__.py +++ b/torchbenchmark/models/torch_multimodal_clip/__init__.py @@ -87,4 +87,5 @@ def eval(self): ) score = image_embedding @ text_embedding.t() - return self.text[torch.argmax(score)] + indices = torch.argmax(score, dim=1) + return [self.texts[i][indices[i].item()] for i in range(self.batch_size)]