Skip to content

Commit

Permalink
typefix
Browse files Browse the repository at this point in the history
  • Loading branch information
rittik9 committed Jan 8, 2025
1 parent 0b29244 commit 2b7347d
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions src/torchmetrics/functional/multimodal/clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,9 @@ def _detect_modality(input_data: Union[Tensor, List[Tensor], List[str], str]) ->
raise ValueError("Could not automatically determine modality for input_data")


def _process_image_data(images: Union[Tensor, List[Tensor]]) -> Union[Tensor, List[Tensor]]:
def _process_image_data(images: Union[Tensor, List[Tensor]]) -> List[Tensor]:
"""Helper function to process image data."""
if not isinstance(images, list) and images.ndim == 3:
images = [images]
images = [images] if not isinstance(images, list) and images.ndim == 3 else list(images)
if not all(i.ndim == 3 for i in images):
raise ValueError("Expected all images to be 3d but found image that has either more or less")
return images
Expand Down Expand Up @@ -148,12 +147,12 @@ def _clip_score_update(
else _process_text_data(cast(Union[str, List[str]], target))
)

# Verify matching lengths
if len(source_data) != len(target_data):
raise ValueError(
"Expected the number of source and target examples to be the same but got "
f"{len(source_data)} and {len(target_data)}"
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if source_modality == "image" and isinstance(source_data[0], Tensor):
device = source_data[0].device
Expand Down

0 comments on commit 2b7347d

Please sign in to comment.