Skip to content

Commit

Permalink
Better reprs and QOL (#96)
Browse files Browse the repository at this point in the history
* Add serialization to dict at the video object level

* Store and retrieve video backend metadata

* Try to restore source video when available

* Rename symbol

* Use backend metadata when available when serializing

* Fix backend metadata factory

* Re-embed videos when saving labels with embedded videos

* Fix serialization and logic for checking for embedded images

* Fix multi-frame decoding

* Fix docstring order

* Add method to embed a list of frames and update the objects

* Fix order of operations

* Add embed_videos

* Fix mid-level embedding function

* Hash videos by ID

* Add property to return embedded frame indices

* Hash LabeledFrame by ID and add convenience checks for instance types

* Labels.user_labeled_frames

* Fix JABS

* Tests

* Add live coverage support

* Expose high level embedding

* Separate replace video and support restoring source

* Update method

* Append/extend

* Skeleton.edge_names

* Skeleton repr

* Instance repr

* Add better video docstrings

* High level filename replacement

* Type hinting

* Lint

* Add Video(filename) syntactic sugar

* Shim for py3.8

* Coverage

* Shim

* Windows fix

* Windows test fix

* PredictedInstance repr

* Windows test fix again

* Fix test
  • Loading branch information
talmo authored Jun 5, 2024
1 parent f6f939b commit cd96be4
Show file tree
Hide file tree
Showing 8 changed files with 346 additions and 9 deletions.
5 changes: 4 additions & 1 deletion sleap_io/io/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ def load_video(filename: str, **kwargs) -> Video:
"""Load a video file.
Args:
filename: Path to a video file.
filename: The filename(s) of the video. Supported extensions: "mp4", "avi",
"mov", "mj2", "mkv", "h5", "hdf5", "slp", "png", "jpg", "jpeg", "tif",
"tiff", "bmp". If the filename is a list, a list of image filenames are
expected. If filename is a folder, it will be searched for images.
Returns:
A `Video` object.
Expand Down
23 changes: 23 additions & 0 deletions sleap_io/model/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,13 @@ def __len__(self) -> int:
"""Return the number of points in the instance."""
return len(self.points)

def __repr__(self) -> str:
"""Return a readable representation of the instance."""
pts = self.numpy().tolist()
track = f'"{self.track.name}"' if self.track is not None else self.track

return f"Instance(points={pts}, track={track})"

@property
def n_visible(self) -> int:
"""Return the number of visible points in the instance."""
Expand Down Expand Up @@ -327,6 +334,22 @@ class PredictedInstance(Instance):
score: float = 0.0
tracking_score: Optional[float] = 0

def __repr__(self) -> str:
"""Return a readable representation of the instance."""
pts = self.numpy().tolist()
track = f'"{self.track.name}"' if self.track is not None else self.track

score = str(self.score) if self.score is None else f"{self.score:.2f}"
tracking_score = (
str(self.tracking_score)
if self.tracking_score is None
else f"{self.tracking_score:.2f}"
)
return (
f"PredictedInstance(points={pts}, track={track}, "
f"score={score}, tracking_score={tracking_score})"
)

@classmethod
def from_numpy( # type: ignore[override]
cls,
Expand Down
140 changes: 137 additions & 3 deletions sleap_io/model/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from attrs import define, field
from typing import Union, Optional, Any
import numpy as np

from pathlib import Path
from sleap_io.model.skeleton import Skeleton


Expand Down Expand Up @@ -57,6 +57,14 @@ class Labels:

def __attrs_post_init__(self):
"""Append videos, skeletons, and tracks seen in `labeled_frames` to `Labels`."""
self.update()

def update(self):
"""Update data structures based on contents.
This function will update the list of skeletons, videos and tracks from the
labeled frames, instances and suggestions.
"""
for lf in self.labeled_frames:
if lf.video not in self.videos:
self.videos.append(lf.video)
Expand All @@ -68,7 +76,13 @@ def __attrs_post_init__(self):
if inst.track is not None and inst.track not in self.tracks:
self.tracks.append(inst.track)

def __getitem__(self, key: int) -> list[LabeledFrame] | LabeledFrame:
for sf in self.suggestions:
if sf.video not in self.videos:
self.videos.append(sf.video)

def __getitem__(
self, key: int | slice | list[int] | np.ndarray | tuple[Video, int]
) -> list[LabeledFrame] | LabeledFrame:
"""Return one or more labeled frames based on indexing criteria."""
if type(key) == int:
return self.labeled_frames[key]
Expand Down Expand Up @@ -111,14 +125,58 @@ def __repr__(self) -> str:
f"labeled_frames={len(self.labeled_frames)}, "
f"videos={len(self.videos)}, "
f"skeletons={len(self.skeletons)}, "
f"tracks={len(self.tracks)}"
f"tracks={len(self.tracks)}, "
f"suggestions={len(self.suggestions)}"
")"
)

def __str__(self) -> str:
"""Return a readable representation of the labels."""
return self.__repr__()

def append(self, lf: LabeledFrame, update: bool = True):
"""Append a labeled frame to the labels.
Args:
lf: A labeled frame to add to the labels.
update: If `True` (the default), update list of videos, tracks and
skeletons from the contents.
"""
self.labeled_frames.append(lf)

if update:
if lf.video not in self.videos:
self.videos.append(lf.video)

for inst in lf:
if inst.skeleton not in self.skeletons:
self.skeletons.append(inst.skeleton)

if inst.track is not None and inst.track not in self.tracks:
self.tracks.append(inst.track)

def extend(self, lfs: list[LabeledFrame], update: bool = True):
"""Append a labeled frame to the labels.
Args:
lfs: A list of labeled frames to add to the labels.
update: If `True` (the default), update list of videos, tracks and
skeletons from the contents.
"""
self.labeled_frames.extend(lfs)

if update:
for lf in lfs:
if lf.video not in self.videos:
self.videos.append(lf.video)

for inst in lf:
if inst.skeleton not in self.skeletons:
self.skeletons.append(inst.skeleton)

if inst.track is not None and inst.track not in self.tracks:
self.tracks.append(inst.track)

def numpy(
self,
video: Optional[Union[Video, int]] = None,
Expand Down Expand Up @@ -417,3 +475,79 @@ def replace_videos(
for sf in self.suggestions:
if sf.video in video_map:
sf.video = video_map[sf.video]

def replace_filenames(
self,
new_filenames: list[str | Path] | None = None,
filename_map: dict[str | Path, str | Path] | None = None,
prefix_map: dict[str | Path, str | Path] | None = None,
):
"""Replace video filenames.
Args:
new_filenames: List of new filenames. Must have the same length as the
number of videos in the labels.
filename_map: Dictionary mapping old filenames (keys) to new filenames
(values).
prefix_map: Dictonary mapping old prefixes (keys) to new prefixes (values).
Notes:
Only one of the argument types can be provided.
"""
n = 0
if new_filenames is not None:
n += 1
if filename_map is not None:
n += 1
if prefix_map is not None:
n += 1
if n != 1:
raise ValueError(
"Exactly one input method must be provided to replace filenames."
)

if new_filenames is not None:
if len(self.videos) != len(new_filenames):
raise ValueError(
f"Number of new filenames ({len(new_filenames)}) does not match "
f"the number of videos ({len(self.videos)})."
)

for video, new_filename in zip(self.videos, new_filenames):
video.replace_filename(new_filename)

elif filename_map is not None:
for video in self.videos:
for old_fn, new_fn in filename_map.items():
if type(video.filename) == list:
new_fns = []
for fn in video.filename:
if Path(fn) == Path(old_fn):
new_fns.append(new_fn)
else:
new_fns.append(fn)
video.replace_filename(new_fns)
else:
if Path(video.filename) == Path(old_fn):
video.replace_filename(new_fn)

elif prefix_map is not None:
for video in self.videos:
for old_prefix, new_prefix in prefix_map.items():
old_prefix, new_prefix = Path(old_prefix), Path(new_prefix)

if type(video.filename) == list:
new_fns = []
for fn in video.filename:
fn = Path(fn)
if fn.as_posix().startswith(old_prefix.as_posix()):
new_fns.append(new_prefix / fn.relative_to(old_prefix))
else:
new_fns.append(fn)
video.replace_filename(new_fns)
else:
fn = Path(video.filename)
if fn.as_posix().startswith(old_prefix.as_posix()):
video.replace_filename(
new_prefix / fn.relative_to(old_prefix)
)
10 changes: 10 additions & 0 deletions sleap_io/model/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@ def edge_inds(self) -> list[Tuple[int, int]]:
for edge in self.edges
]

@property
def edge_names(self) -> list[str, str]:
"""Edge names as a list of 2-tuples with string node names."""
return [(edge.source.name, edge.destination.name) for edge in self.edges]

@property
def flipped_node_inds(self) -> list[int]:
"""Returns node indices that should be switched when horizontally flipping."""
Expand All @@ -183,6 +188,11 @@ def __len__(self) -> int:
"""Return the number of nodes in the skeleton."""
return len(self.nodes)

def __repr__(self) -> str:
"""Return a readable representation of the skeleton."""
nodes = ", ".join([f'"{node}"' for node in self.node_names])
return "Skeleton(" f"nodes=[{nodes}], " f"edges={self.edge_inds}" ")"

def index(self, node: Node | str) -> int:
"""Return the index of a node specified as a `Node` or string name."""
if type(node) == str:
Expand Down
19 changes: 15 additions & 4 deletions sleap_io/model/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import attrs
from typing import Tuple, Optional, Optional
import numpy as np
from sleap_io.io.video import VideoBackend, MediaVideo, HDF5Video
from sleap_io.io.video import VideoBackend, MediaVideo, HDF5Video, ImageVideo
from pathlib import Path


Expand All @@ -23,7 +23,10 @@ class Video:
backend appropriately.
Attributes:
filename: The filename(s) of the video.
filename: The filename(s) of the video. Supported extensions: "mp4", "avi",
"mov", "mj2", "mkv", "h5", "hdf5", "slp", "png", "jpg", "jpeg", "tif",
"tiff", "bmp". If the filename is a list, a list of image filenames are
expected. If filename is a folder, it will be searched for images.
backend: An object that implements the basic methods for reading and
manipulating frames of a specific video type.
backend_metadata: A dictionary of metadata specific to the backend. This is
Expand All @@ -45,7 +48,12 @@ class Video:
backend_metadata: dict[str, any] = attrs.field(factory=dict)
source_video: Optional[Video] = None

EXTS = MediaVideo.EXTS + HDF5Video.EXTS
EXTS = MediaVideo.EXTS + HDF5Video.EXTS + ImageVideo.EXTS

def __attrs_post_init__(self):
"""Post init syntactic sugar."""
if self.backend is None and self.exists():
self.open()

def __attrs_post_init__(self):
"""Post init syntactic sugar."""
Expand All @@ -65,7 +73,10 @@ def from_filename(
"""Create a Video from a filename.
Args:
filename: Path to video file(s).
filename: The filename(s) of the video. Supported extensions: "mp4", "avi",
"mov", "mj2", "mkv", "h5", "hdf5", "slp", "png", "jpg", "jpeg", "tif",
"tiff", "bmp". If the filename is a list, a list of image filenames are
expected. If filename is a folder, it will be searched for images.
dataset: Name of dataset in HDF5 file.
grayscale: Whether to force grayscale. If None, autodetect on first frame
load.
Expand Down
9 changes: 9 additions & 0 deletions tests/model/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def test_instance():
inst = Instance({"A": [0, 1], "B": [2, 3]}, skeleton=Skeleton(["A", "B"]))
assert_equal(inst.numpy(), [[0, 1], [2, 3]])
assert type(inst["A"]) == Point
assert str(inst) == "Instance(points=[[0.0, 1.0], [2.0, 3.0]], track=None)"

inst.track = Track("trk")
assert str(inst) == 'Instance(points=[[0.0, 1.0], [2.0, 3.0]], track="trk")'

inst = Instance({"A": [0, 1]}, skeleton=Skeleton(["A", "B"]))
assert_equal(inst.numpy(), [[0, 1], [np.nan, np.nan]])
Expand Down Expand Up @@ -155,3 +159,8 @@ def test_predicted_instance():
assert inst[0].score == 0.4
assert inst[1].score == 0.5
assert inst.score == 0.6

assert (
str(inst) == "PredictedInstance(points=[[0.0, 1.0], [2.0, 3.0]], track=None, "
"score=0.60, tracking_score=None)"
)
Loading

0 comments on commit cd96be4

Please sign in to comment.