diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 43d71f84..ef50138f 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -35,7 +35,7 @@ def make(self, key): """On a per-chunk basis, check for the presence of new block, insert into Block table.""" # find the 0s # that would mark the start of a new block - # if the 0 is the first index - look back at the previous chunk + # In the BlockState data - if the 0 is the first index - look back at the previous chunk # if the previous timestamp belongs to a previous epoch -> block_end is the previous timestamp # else block_end is the timestamp of this 0 chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") @@ -55,17 +55,31 @@ def make(self, key): key["experiment_name"], previous_block_start, chunk_end ) - block_state_query = acquisition.Environment.BlockState & exp_key & chunk_restriction - block_state_df = fetch_stream(block_state_query)[previous_block_start:chunk_end] + # detecting block end times + # pellet count reset - find 0s in BlockState - block_ends = block_state_df[block_state_df.pellet_ct.diff() < 0] + block_state_query = acquisition.Environment.BlockState & exp_key & chunk_restriction + block_state_df = fetch_stream(block_state_query) + block_state_df.index = block_state_df.index.round( + "us" + ) # timestamp precision in DJ is only at microseconds + block_state_df = block_state_df.loc[ + (block_state_df.index > previous_block_start) & (block_state_df.index <= chunk_end) + ] + + block_ends = block_state_df[block_state_df.pellet_ct == 0] + # account for the double 0s - find any 0s that are within 1 second of each other, remove the 2nd one + double_0s = block_ends.index.to_series().diff().dt.total_seconds() < 1 + # find the indices of the 2nd 0s and remove + double_0s = double_0s.shift(-1).fillna(False) + block_ends = block_ends[~double_0s] block_entries = [] for idx, block_end in enumerate(block_ends.index): if idx == 0: if previous_block_key: # if there is a previous block - insert "block_end" for the previous block - previous_pellet_time = block_state_df[:block_end].index[-2] + previous_pellet_time = block_state_df[:block_end].index[-1] previous_epoch = ( acquisition.Epoch.join(acquisition.EpochEnd, left=True) & exp_key @@ -233,6 +247,10 @@ def make(self, key): } ) + # update block_end if last timestamp of encoder_df is before the current block_end + if encoder_df.index[-1] < block_end: + block_end = encoder_df.index[-1] + # Subject data # Get all unique subjects that visited the environment over the entire exp; # For each subject, see 'type' of visit most recent to start of block @@ -248,6 +266,7 @@ def make(self, key): _df = subject_visits_df[subject_visits_df.id == subject_name] if _df.type[-1] != "Exit": subject_names.append(subject_name) + for subject_name in subject_names: # positions - query for CameraTop, identity_name matches subject_name, pos_query = ( @@ -291,6 +310,14 @@ def make(self, key): } ) + # update block_end if last timestamp of pos_df is before the current block_end + if pos_df.index[-1] < block_end: + block_end = pos_df.index[-1] + + if block_end != (Block & key).fetch1("block_end"): + Block.update1({**key, "block_end": block_end}) + self.update1({**key, "block_duration": (block_end - block_start).total_seconds() / 3600}) + @schema class BlockSubjectAnalysis(dj.Computed): @@ -501,7 +528,7 @@ def make(self, key): @schema class BlockPlots(dj.Computed): - definition = """ + definition = """ -> BlockAnalysis --- subject_positions_plot: longblob