Skip to content

Commit

Permalink
Merge pull request #372 from SainsburyWellcomeCentre/bonsai-sleap0.3-PoseReader
Browse files (browse the repository at this point in the history)

Updated Pose Reader for Bonsai-Sleap0.3
  • Loading branch information
jkbhagatio authored Jul 3, 2024
2 parents 25cc4b7 + 4a72c47 commit cb7d84c
Showing 1 changed file with 55 additions and 39 deletions.
94 changes: 55 additions & 39 deletions aeon/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,23 +286,35 @@ def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/process
# `pattern` for this reader should typically be '<hpcnode>_<jobid>*'
super().__init__(pattern, columns=None)
self._model_root = model_root
self.config_file = None # requires reading the data file to be set

def read(self, file: Path) -> pd.DataFrame:
"""Reads data from the Harp-binarized tracking file."""
# Get config file from `file`, then bodyparts from config file.
model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[1:])
model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[-4:])
config_file_dir = Path(self._model_root) / model_dir
if not config_file_dir.exists():
raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
config_file = self.get_config_file(config_file_dir)
parts = self.get_bodyparts(config_file)
self.config_file = self.get_config_file(config_file_dir)
identities = self.get_class_names()
parts = self.get_bodyparts()

# Using bodyparts, assign column names to Harp register values, and read data in default format.
columns = ["class", "class_likelihood"]
for part in parts:
columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
self.columns = columns
data = super().read(file)
try: # Bonsai.Sleap0.2
bonsai_sleap_v = 0.2
columns = ["identity", "identity_likelihood"]
for part in parts:
columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
self.columns = columns
data = super().read(file)
except ValueError: # column mismatch; Bonsai.Sleap0.3
bonsai_sleap_v = 0.3
columns = ["identity"]
columns.extend([f"{identity}_likelihood" for identity in identities])
for part in parts:
columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
self.columns = columns
data = super().read(file)

# Drop any repeat parts.
unique_parts, unique_idxs = np.unique(parts, return_index=True)
Expand All @@ -315,54 +327,74 @@ def read(self, file: Path) -> pd.DataFrame:
parts = unique_parts

# Set new columns, and reformat `data`.
data = self.class_int2str(data)
n_parts = len(parts)
part_data_list = [pd.DataFrame()] * n_parts
new_columns = ["class", "class_likelihood", "part", "x", "y", "part_likelihood"]
new_columns = ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]
new_data = pd.DataFrame(columns=new_columns)
for i, part in enumerate(parts):
part_columns = ["class", "class_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"]
part_columns = columns[0 : (len(identities) + 1)] if bonsai_sleap_v == 0.3 else columns[0:2]
part_columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
part_data = pd.DataFrame(data[part_columns])
if bonsai_sleap_v == 0.3: # combine all identity_likelihood cols into a single col as dict
part_data["identity_likelihood"] = part_data.apply(
lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, axis=1
)
part_data.drop(columns=columns[1 : (len(identities) + 1)], inplace=True)
part_data = part_data[ # reorder columns
["identity", "identity_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"]
]
part_data.insert(2, "part", part)
part_data.columns = new_columns
part_data_list[i] = part_data
new_data = pd.concat(part_data_list)
return new_data.sort_index()

def get_class_names(self, file: Path) -> list[str]:
def get_class_names(self) -> list[str]:
"""Returns a list of classes from a model's config file."""
classes = None
with open(file) as f:
with open(self.config_file) as f:
config = json.load(f)
if file.stem == "confmap_config": # SLEAP
if self.config_file.stem == "confmap_config": # SLEAP
try:
heads = config["model"]["heads"]
classes = util.find_nested_key(heads, "class_vectors")["classes"]
except KeyError as err:
if not classes:
raise KeyError(f"Cannot find class_vectors in {file}.") from err
raise KeyError(f"Cannot find class_vectors in {self.config_file}.") from err
return classes

def get_bodyparts(self, file: Path) -> list[str]:
def get_bodyparts(self) -> list[str]:
"""Returns a list of bodyparts from a model's config file."""
parts = []
with open(file) as f:
with open(self.config_file) as f:
config = json.load(f)
if file.stem == "confmap_config": # SLEAP
if self.config_file.stem == "confmap_config": # SLEAP
try:
heads = config["model"]["heads"]
parts = [util.find_nested_key(heads, "anchor_part")]
parts += util.find_nested_key(heads, "part_names")
except KeyError as err:
if not parts:
raise KeyError(f"Cannot find bodyparts in {file}.") from err
raise KeyError(f"Cannot find bodyparts in {self.config_file}.") from err
return parts

def class_int2str(self, data: pd.DataFrame) -> pd.DataFrame:
    """Converts a class integer in a tracking data dataframe to its associated string (subject id).

    Args:
        data: Tracking dataframe with an `identity` column holding class indices.
            It is modified in place.

    Returns:
        The same dataframe object, with each `identity` value equal to index `i`
        replaced by the `i`-th entry of the config's `classes`. Returned untouched
        when the reader's config file is not a SLEAP `confmap_config`.

    Raises:
        KeyError: If the SLEAP config has no `model.heads` section to search.
    """
    if self.config_file.stem != "confmap_config":  # only SLEAP configs are handled
        return data
    with open(self.config_file, encoding="utf-8") as f:
        config = json.load(f)
    try:
        heads = config["model"]["heads"]
        classes = util.find_nested_key(heads, "classes")
    except KeyError as err:
        raise KeyError(f"Cannot find classes in {self.config_file}.") from err
    # Subject ids are strings: widen the column to object first, because writing a
    # string into a numeric column through `.loc` is an incompatible-dtype set in pandas.
    data["identity"] = data["identity"].astype(object)
    for i, subj in enumerate(classes):
        data.loc[data["identity"] == i, "identity"] = subj
    return data

@classmethod
def get_config_file(
cls,
config_file_dir: Path,
config_file_names: None | list[str] = None,
) -> Path:
def get_config_file(cls, config_file_dir: Path, config_file_names: None | list[str] = None) -> Path:
"""Returns the config file from a model's config directory."""
if config_file_names is None:
config_file_names = ["confmap_config.json"] # SLEAP (add for other trackers to this list)
Expand All @@ -375,22 +407,6 @@ def get_config_file(
raise FileNotFoundError(f"Cannot find config file in {config_file_dir}")
return config_file

@classmethod
def class_int2str(cls, data: pd.DataFrame, config_file_dir: Path) -> pd.DataFrame:
    """Converts a class integer in a tracking data dataframe to its associated string (subject id).

    Args:
        data: Tracking dataframe with a `class` column holding class indices.
            It is modified in place.
        config_file_dir: Model directory passed to `get_config_file` to locate the config.

    Returns:
        The same dataframe object, with each `class` value equal to index `i`
        replaced by the `i`-th entry of the config's `classes`. Returned untouched
        when the located config file is not a SLEAP `confmap_config`.

    Raises:
        KeyError: If the SLEAP config has no `model.heads` section to search.
    """
    config_file = cls.get_config_file(config_file_dir)
    if config_file.stem != "confmap_config":  # only SLEAP configs are handled
        return data
    with open(config_file, encoding="utf-8") as f:
        config = json.load(f)
    try:
        heads = config["model"]["heads"]
        classes = util.find_nested_key(heads, "classes")
    except KeyError as err:
        raise KeyError(f"Cannot find classes in {config_file}.") from err
    # Subject ids are strings: widen the column to object first, because writing a
    # string into a numeric column through `.loc` is an incompatible-dtype set in pandas.
    data["class"] = data["class"].astype(object)
    for i, subj in enumerate(classes):
        data.loc[data["class"] == i, "class"] = subj
    return data


def from_dict(data, pattern=None):
reader_type = data.get("type", None)
Expand Down

0 comments on commit cb7d84c

Please sign in to comment.