Skip to content

Commit

Permalink
chore: ruff
Browse files — browse the repository at this point in the history
  • Loading branch information
ttngu207 committed Dec 12, 2024
1 parent 642589d commit 0910695
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 17 deletions.
28 changes: 14 additions & 14 deletions aeon/dj_pipeline/analysis/block_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ class Patch(dj.Part):
---
in_patch_timestamps: longblob # timestamps when a subject is at a specific patch
in_patch_time: float # total seconds spent in this patch for this block
in_patch_rfid_timestamps=null: longblob # timestamps when a subject is at a specific patch based on RFID
in_patch_rfid_timestamps=null: longblob # in_patch_timestamps based on RFID
pellet_count: int
pellet_timestamps: longblob
patch_threshold: longblob # patch threshold value at each pellet delivery
Expand Down Expand Up @@ -528,8 +528,8 @@ def make(self, key):

# subject-rfid mapping
rfid2subj_map = {
int(l): s
for s, l in zip(
int(lab_id): subj_name
for subj_name, lab_id in zip(
*(subject.SubjectDetail.proj("lab_id") & f"subject in {tuple(subject_names)}").fetch(
"subject", "lab_id"
),
Expand Down Expand Up @@ -1183,11 +1183,11 @@ def calculate_running_preference(group, pref_col, out_col):
pel_patches = [p for p in patch_names if "dummy" not in p.lower()] # exclude dummy patches
data = []
for patch in pel_patches:
for subject in subject_names:
for subject_name in subject_names:
data.append(
{
"patch_name": patch,
"subject_name": subject,
"subject_name": subject_name,
"time": wheel_ts[patch],
"weighted_dist": np.empty_like(wheel_ts[patch]),
}
Expand Down Expand Up @@ -1287,18 +1287,18 @@ def norm_inv_norm(group):
df = subj_wheel_pel_weighted_dist
# Iterate through patches and subjects to create plots
for i, patch in enumerate(pel_patches, start=1):
for j, subject in enumerate(subject_names, start=1):
for j, subject_name in enumerate(subject_names, start=1):
# Filter data for this patch and subject
times = df.loc[patch].loc[subject]["time"]
norm_values = df.loc[patch].loc[subject]["norm_value"]
wheel_prefs = df.loc[patch].loc[subject]["wheel_pref"]
times = df.loc[patch].loc[subject_name]["time"]
norm_values = df.loc[patch].loc[subject_name]["norm_value"]
wheel_prefs = df.loc[patch].loc[subject_name]["wheel_pref"]

# Add wheel_pref trace
weighted_patch_pref_fig.add_trace(
go.Scatter(
x=times,
y=wheel_prefs,
name=f"{subject} - wheel_pref",
name=f"{subject_name} - wheel_pref",
line={
"color": subject_colors[i - 1],
"dash": patch_linestyles_dict[patch],
Expand All @@ -1316,7 +1316,7 @@ def norm_inv_norm(group):
go.Scatter(
x=times,
y=norm_values,
name=f"{subject} - norm_value",
name=f"{subject_name} - norm_value",
line={
"color": subject_colors[i - 1],
"dash": patch_linestyles_dict[patch],
Expand Down Expand Up @@ -1846,8 +1846,8 @@ def get_foraging_bouts(
# - For the foraging bout end time, we need to account for the final pellet delivery time
# - Filter out events with < `min_pellets`
# - For final events, get: duration, n_pellets, cum_wheel_distance -> add to returned DF
for subject in subject_patch_data.index.unique("subject_name"):
cur_subject_data = subject_patch_data.xs(subject, level="subject_name")
for subject_name in subject_patch_data.index.unique("subject_name"):
cur_subject_data = subject_patch_data.xs(subject_name, level="subject_name")
n_pels = sum([arr.size for arr in cur_subject_data["pellet_timestamps"].values])
if n_pels < min_pellets:
continue
Expand Down Expand Up @@ -1929,7 +1929,7 @@ def get_foraging_bouts(
"end": bout_starts_ends[:, 1],
"n_pellets": bout_pellets,
"cum_wheel_dist": bout_cum_wheel_dist,
"subject": subject,
"subject": subject_name,
}
),
]
Expand Down
17 changes: 14 additions & 3 deletions aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
from aeon.dj_pipeline.analysis.block_analysis import *
"""Script to update in_patch_rfid_timestamps for all blocks that are missing it."""

import datajoint as dj

from aeon.dj_pipeline import acquisition, fetch_stream, streams, subject
from aeon.dj_pipeline.analysis.block_analysis import Block, BlockAnalysis, BlockSubjectAnalysis

logger = dj.logger


def update_in_patch_rfid_timestamps(block_key):
"""Update in_patch_rfid_timestamps for a given block_key.
Args:
block_key (dict): block key
"""
logger.info(f"Updating in_patch_rfid_timestamps for {block_key}")

block_key = (Block & block_key).fetch1("KEY")
Expand All @@ -15,8 +25,8 @@ def update_in_patch_rfid_timestamps(block_key):
subject_names = (BlockAnalysis.Subject & block_key).fetch("subject_name")

rfid2subj_map = {
int(l): s
for s, l in zip(
int(lab_id): subj_name
for subj_name, lab_id in zip(
*(subject.SubjectDetail.proj("lab_id") & f"subject in {tuple(subject_names)}").fetch(
"subject", "lab_id"
),
Expand Down Expand Up @@ -55,6 +65,7 @@ def update_in_patch_rfid_timestamps(block_key):


def main():
"""Update in_patch_rfid_timestamps for all blocks that are missing it."""
block_keys = BlockSubjectAnalysis & (
BlockSubjectAnalysis.Patch & "in_patch_rfid_timestamps IS NULL"
).fetch("KEY")
Expand Down

0 comments on commit 0910695

Please sign in to comment.