-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathsuppfig_end_session_histogram.py
79 lines (66 loc) · 3.01 KB
/
suppfig_end_session_histogram.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
HISTOGRAM OF SESSION END STATUSES DURING TRAINING
Miles Wells, UCL, 2019
"""
import os
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import datajoint as dj
from ibl_pipeline import acquisition
from paper_behavior_functions import \
(figpath, query_sessions, query_subjects, group_colors, seaborn_style,
FIGURE_HEIGHT, FIGURE_WIDTH, QUERY)
# Explicit check rather than `assert`, which is stripped when Python runs with -O.
if not QUERY:
    raise RuntimeError('This script requires a DataJoint instance, which was removed in Dec 2023.')
save_path = figpath()  # Our figure save path
colors = group_colors()  # presumably a sequence of colours; indexed by status id below
seaborn_style()  # apply the project's shared seaborn plotting style

endcriteria = dj.create_virtual_module('SessionEndCriteriaImplemented',
                                       'group_shared_end_criteria')
sessions = query_sessions().proj(session_start_date='date(session_start_time)')
# Date of each subject's first session, taken over all acquisition sessions
# (not only those returned by query_sessions()).
subj_crit = query_subjects().aggr(
    acquisition.Session(),
    first_day='min(date(session_start_time))').proj('first_day')
# n = number of days between this session and the subject's first session (0 = first day).
session_num = (sessions * subj_crit).proj(n='DATEDIFF(session_start_date, first_day)')
df = (endcriteria.SessionEndCriteriaImplemented * session_num).fetch(format='frame')  # Fetch data
# Stacked histogram of end statuses against days since the first session.
fig, ax = plt.subplots(1, 1, figsize=(FIGURE_WIDTH/4, FIGURE_HEIGHT))
# Convert statuses to numerical ids 0..k-1, in order of first appearance.
ids = {k: v for v, k in enumerate(df['end_status'].unique())}
df['end_status_id'] = df['end_status'].map(ids)
# Bin edges in days since the first session; values outside [0, 34] are not counted.
bins = [0, 6, 13, 20, 27, 34]
# One column of day numbers per status id. `ax` must be passed explicitly:
# DataFrame.plot otherwise opens a new figure, leaving `fig` empty and
# discarding the requested figsize.
ax = df.pivot(columns='end_status_id').n.plot(
    kind='hist', ax=ax, color=colors, bins=bins, stacked=True, density=True)
ax.legend(list(ids.keys()))
ax.set_xlabel('Session #')
ax.set_ylabel('Frequency')
fig.tight_layout()
fig.savefig(os.path.join(save_path, "suppfig_end_status_histogram.png"))
# Unity plot: per-day proportion of each end status as stacked bars summing to 1.
max_n_days = 40  # sessions on day >= max_n_days are pooled into the final bar
normalize = True  # plot per-day proportions instead of raw session counts
df = df.reset_index()

# Row i corresponds to status id i (ids are 0..k-1, matching the order of `ids`).
status_ids = np.sort(df['end_status_id'].unique())
# One boolean mask per day column; the last column collects every day >= max_n_days.
day_masks = [(df['n'] == n_days) if n_days < max_n_days else (df['n'] >= n_days)
             for n_days in range(max_n_days + 1)]
status_masks = [df['end_status_id'] == status_id for status_id in status_ids]
# counts[i, d]: number of sessions with status id i on day column d.
counts = np.array([[int((status_mask & day_mask).sum()) for day_mask in day_masks]
                   for status_mask in status_masks])

if normalize:
    # Divide each day column by its total; days with no sessions stay 0
    # instead of producing NaN (and a divide-by-zero warning).
    totals = counts.sum(axis=0, keepdims=True)
    counts = np.divide(counts, totals, out=np.zeros(counts.shape), where=totals > 0)

bar_l = range(1, counts.shape[1] + 1)  # day column d is drawn at x = d + 1
# Each status is stacked on top of the cumulative height of the statuses before it.
bottom = np.vstack((np.zeros((1, counts.shape[1])), np.cumsum(counts, axis=0)[:-1, :]))

fig, ax = plt.subplots(1, 1, figsize=(FIGURE_WIDTH / 2, FIGURE_HEIGHT))
for i, (label, heights, base) in enumerate(zip(ids.keys(), counts, bottom)):
    ax.bar(bar_l, heights, bottom=base, width=1, label=label, color=colors[i])
ax.set_xticks(list(range(0, max_n_days + 1, 10)))
ax.set_xlim([0, counts.shape[1]+.5])
ax.set_xlabel('Session #')
ax.set_ylabel('Proportion')
ax.legend(loc='upper right')
plt.tight_layout()
sns.despine(trim=False)
fig.savefig(os.path.join(save_path, "suppfig_end_status_histogram_normalized.png"), dpi=300)
fig.savefig(os.path.join(save_path, "suppfig_end_status_histogram_normalized.pdf"))