-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMouseAgingToolSuite.py
437 lines (340 loc) · 21.3 KB
/
MouseAgingToolSuite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
import streamlit as st
import pandas as pd
import statsmodels.stats.power as smp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import PowerNorm
from lifelines import KaplanMeierFitter
from math import ceil
# Locations of the two input CSV files. Both are absolute, machine-specific
# paths, so the app only starts on a machine that has these files in place.
data_path = 'C:\\Users\\ndsch\\Data\\ITP-Lifespan-Data\\ITP_processed_data\\ITP_2004-2017_concat.csv'
purchase_data_path = 'C:\\Users\\ndsch\\Data\\ITP-Lifespan-Data\\Mouse_costs\\JAX_HET3_prices.csv'

# Read the lifespan records first, then the mouse purchase price list.
data, purchase_data = pd.read_csv(data_path), pd.read_csv(purchase_data_path)

# Heading shown at the top of every page of the Streamlit app.
st.title('The Mouse Aging Tool Suite')
def power_analysis_page():
    """Render the power-analysis page: three sample-size heatmaps.

    Reads the significance level from a sidebar slider, then asks
    ``compute_required_sample_size`` for the required sample size over a
    grid of power (70%-90%) and effect size (10%-30% of mean control
    lifespan) for males, females and both sexes combined.  The three grids
    are drawn as stacked heatmaps that share one colour scale so they can be
    compared directly.
    """
    st.sidebar.header('Power Analysis')
    alpha = st.sidebar.slider("Select Alpha", min_value=0.01, max_value=0.10, value=0.05, step=0.01)
    st.sidebar.write("Here, the two-sample independent t-test with the function smp.TTestIndPower().solve_power from the statsmodels.stats.power module has been employed.")

    power_range = (0.7, 0.90)
    effect_size_range = (0.1, 0.3)

    # (panel title, value of the 'sex' column to keep; None keeps both sexes)
    panels = [("Males", 'm'), ("Females", 'f'), ("Combined Sexes", None)]
    grids = [
        (title, *compute_required_sample_size(sex_code, alpha, power_range, effect_size_range))
        for title, sex_code in panels
    ]

    # One normalisation shared by all panels.  gamma=0.5 stretches the low
    # end of the scale, where most of the required sample sizes fall.
    vmin_value = min(sample_sizes.min() for *_, sample_sizes in grids)
    vmax_value = max(sample_sizes.max() for *_, sample_sizes in grids)
    norm = PowerNorm(gamma=0.5, vmin=vmin_value, vmax=vmax_value)

    fig, axes = plt.subplots(3, 1, figsize=(12, 20))
    for ax, (title, power_values, effect_size_values, sample_sizes) in zip(axes, grids):
        # The limits live on `norm`; Matplotlib documents passing vmin/vmax
        # together with a norm instance as an error, so only norm is given.
        sns.heatmap(sample_sizes, cmap='YlGnBu', annot=True, fmt=".0f",
                    cbar_kws={'label': 'Required Sample Size'},
                    ax=ax,
                    yticklabels=[f"{x:.1%}" for x in effect_size_values],
                    xticklabels=[f"{x:.1%}" for x in power_values],
                    norm=norm)
        ax.set_title(title)
        ax.set_xlabel("Power (%)")
        ax.set_ylabel("Desired Effect Size (%)")

    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; without this each
    # Streamlit rerun (e.g. every slider move) would leak one figure.
    plt.close(fig)
# Cost Estimation Page Code
def cost_estimation_page():
    """Render the mouse purchase / housing cost calculators and cost plots.

    Sidebar: collects the study design (mice per dose, doses, treatments),
    a price per mouse (looked up by age in ``purchase_data`` or entered by
    hand) and housing parameters, then reports purchase cost, housing cost
    and the grand total.

    Main area: a line plot of purchase cost vs. price per mouse, a line plot
    of housing cost vs. daily cage cost, and a heatmap of total cost (in
    millions) over a grid of both.

    Returns early, after a sidebar warning, when age-based pricing is chosen
    and ``purchase_data`` has no row for the resulting age in weeks.
    """
    max_price_week = 78          # oldest age (in weeks) looked up in the price table
    days_per_month = 30.41666666  # approximation used to convert months to days

    # ---- Purchase cost -------------------------------------------------
    st.sidebar.subheader('Mouse Purchase Cost Calculator')
    mice_per_dose = st.sidebar.number_input('Mice per dose', min_value=1, value=200)
    doses_per_treatment = st.sidebar.number_input('Number of doses per treatment', min_value=1, value=1)
    number_of_treatments = st.sidebar.number_input('Number of treatments', min_value=1, value=1)
    pricing_method = st.sidebar.radio("Choose pricing method:", ["Age-based pricing", "Custom price"])

    if pricing_method == "Age-based pricing":
        # The age can be entered in several units; it is normalised to days.
        age_unit = st.sidebar.selectbox('Choose age unit for purchase:', ['days', 'weeks', 'months'])
        if age_unit == 'days':
            age_value = st.sidebar.number_input('Age in days at purchase', min_value=1, max_value=78*7)
        elif age_unit == 'weeks':
            age_value = st.sidebar.number_input('Age in weeks at purchase', min_value=1, max_value=78)
            age_value *= 7
        else:
            age_value = st.sidebar.number_input('Age in months at purchase', min_value=1, max_value=int(78/4))
            age_value *= days_per_month

        # Round up to whole weeks (an age past a week boundary is priced as
        # the next week) and clamp to the last week that is looked up.
        # ceil() returns an int, so the warning below never reads "5.0 weeks".
        corresponding_week = min(ceil(age_value / 7), max_price_week)

        prices_for_week = purchase_data.loc[purchase_data['Age (weeks)'] == corresponding_week, 'Price'].values
        if prices_for_week.size == 0:
            st.sidebar.warning(f"No price available for {corresponding_week} weeks. Please choose a different age.")
            return
        mouse_price = prices_for_week[0]
    else:
        mouse_price = st.sidebar.number_input('Enter custom price per mouse ($)', min_value=0.0, value=20.0)

    total_mice_needed = mice_per_dose * doses_per_treatment * number_of_treatments
    total_purchase_cost = total_mice_needed * mouse_price
    st.sidebar.write(f"Total mice needed: {total_mice_needed}")
    st.sidebar.write(f"Cost per mouse: ${mouse_price:.2f}")
    st.sidebar.write(f"Total cost to purchase all mice: ${total_purchase_cost:,.2f}")

    # ---- Housing cost --------------------------------------------------
    st.sidebar.subheader('Mouse Housing Cost Calculator')
    mice_per_cage = st.sidebar.number_input('Enter number of mice per cage', min_value=1, value=5)

    # The housing duration can also be entered in several units.
    time_unit = st.sidebar.selectbox('Choose time unit:', ['days', 'weeks', 'months'])
    if time_unit == 'days':
        time_value = st.sidebar.number_input('Enter number of days to house mice', min_value=1, value=30)
    elif time_unit == 'weeks':
        time_value = st.sidebar.number_input('Enter number of weeks to house mice', min_value=1, value=4)
        time_value *= 7
    else:
        time_value = st.sidebar.number_input('Enter number of months to house mice', min_value=1, value=1)
        time_value *= days_per_month

    # Only the displayed figure is rounded; the cost below uses the
    # unrounded duration.
    rounded_time_value = round(time_value)
    st.sidebar.write(f"This corresponds to {rounded_time_value} days.")

    daily_cage_cost = st.sidebar.number_input('Enter daily cage cost ($)', min_value=0.0, value=1.50)
    total_cages = -(-total_mice_needed // mice_per_cage)  # ceiling division: a partly filled cage still costs a cage
    housing_cost = total_cages * daily_cage_cost * time_value
    st.sidebar.write(f"Total cost to house mice: ${housing_cost:,.2f}")

    grand_total = total_purchase_cost + housing_cost
    st.sidebar.subheader(f'Grand Total: ${grand_total:,.2f}')

    # ---- Plot 1: purchase cost as a function of price per mouse --------
    price_axis = np.linspace(0, 1000, 200)
    fig_purchase, ax1 = plt.subplots(figsize=(10, 6))
    ax1.plot(price_axis, price_axis * total_mice_needed, color='b')
    ax1.set_title("Total Cost vs. Cost Per Mouse")
    ax1.set_xlabel("Cost Per Mouse ($)")
    ax1.set_ylabel("Total Cost ($)")
    ax1.grid(True)
    st.pyplot(fig_purchase)
    # pyplot keeps figures alive until closed; close each one after
    # rendering so Streamlit reruns do not accumulate them.
    plt.close(fig_purchase)

    # ---- Plot 2: housing cost as a function of daily cage cost ---------
    cage_cost_axis = np.linspace(0, 14, 200)  # assumed reasonable range of daily cage costs
    fig_housing, ax2 = plt.subplots(figsize=(10, 6))
    ax2.plot(cage_cost_axis, total_cages * cage_cost_axis * time_value, color='r')
    ax2.set_title("Total Housing Cost vs. Daily Cage Cost")
    ax2.set_xlabel("Daily Cage Cost ($)")
    ax2.set_ylabel("Total Housing Cost ($)")
    ax2.grid(True)
    st.pyplot(fig_housing)
    plt.close(fig_housing)

    # ---- Plot 3: heatmap of total cost over both prices ----------------
    cage_cost_grid = np.arange(0.5, 12.1, 0.5)   # $0.50, $1.00, ..., $12.00
    mouse_price_grid = np.arange(25, 501, 25)    # $25, $50, ..., $500

    # Broadcast a column of purchase costs against a row of housing costs:
    # rows follow mouse price, columns follow daily cage cost.
    purchase_costs = mouse_price_grid[:, np.newaxis] * total_mice_needed
    housing_costs = total_cages * cage_cost_grid[np.newaxis, :] * time_value
    total_costs = (purchase_costs + housing_costs) / 1_000_000  # in millions

    fig_total, ax3 = plt.subplots(figsize=(10, 6))
    # Rows and their labels are flipped together so the price increases
    # upwards on the y axis.
    sns.heatmap(total_costs[::-1], cmap='YlGnBu', annot=True, fmt='.1f', ax=ax3,
                xticklabels=cage_cost_grid, yticklabels=mouse_price_grid[::-1],
                cbar_kws={'label': 'Total Cost (in Millions)'})
    ax3.set_title("Heatmap of Total Cost")
    ax3.set_xlabel("Daily Cage Cost ($)")
    ax3.set_ylabel("Cost Per Mouse ($)")
    fig_total.tight_layout()
    st.pyplot(fig_total)
    plt.close(fig_total)
def compute_required_sample_size(sex_filter=None, alpha=0.05, power_range=(0.7, 0.9), effect_size_range=(0.1, 0.3)):
    """Compute required sample sizes over a power x effect-size grid.

    The effect size is expressed as a fraction of the mean control lifespan
    and converted to Cohen's d using the standard deviation of control
    lifespans taken from the module-level ``data`` frame.

    Args:
        sex_filter: Value of the ``sex`` column to keep (the callers in this
            file pass ``'m'`` or ``'f'``), or ``None`` to use both sexes.
        alpha: Significance level passed to the power solver.
        power_range: ``(low, high)`` bounds of the power axis.
        effect_size_range: ``(low, high)`` bounds of the effect-size axis,
            as fractions of the mean control lifespan.

    Returns:
        ``(power_values, effect_size_values, sample_sizes)`` where the first
        two are 30-point ``np.linspace`` axes and ``sample_sizes[i, j]`` is
        the solver result for ``effect_size_values[i]`` and
        ``power_values[j]``, rounded up to a whole number.

    Raises:
        ValueError: If the selected control lifespans are empty or have no
            spread, so Cohen's d cannot be computed.
    """
    power_values = np.linspace(power_range[0], power_range[1], 30)
    effect_size_values = np.linspace(effect_size_range[0], effect_size_range[1], 30)

    # NOTE(review): controls are selected via the 'group' column here, while
    # bootstrap_page selects them via 'treatment' -- confirm which is intended.
    controls = data[data['group'] == 'Control']
    if sex_filter is not None:
        controls = controls[controls['sex'] == sex_filter]
    lifespans = controls['age(days)']

    mean_lifespan = lifespans.mean()
    std_lifespan = lifespans.std()
    # `not (x > 0)` is also true for NaN, which is what .std() yields for
    # fewer than two rows.  Fail here with a clear message rather than on a
    # NaN deep inside the solver loop.
    if not (std_lifespan > 0):
        raise ValueError(
            f"Cannot compute an effect size for sex_filter={sex_filter!r}: "
            "the selected control lifespans are empty or have zero variance."
        )

    # The solver object does not depend on the grid point, so build it once
    # instead of once per cell (30 x 30 = 900 times).
    power_analysis = smp.TTestIndPower()

    sample_sizes = np.zeros((len(effect_size_values), len(power_values)))
    for i, desired_effect_percentage in enumerate(effect_size_values):
        cohens_d = desired_effect_percentage * mean_lifespan / std_lifespan
        for j, power in enumerate(power_values):
            sample_size = power_analysis.solve_power(effect_size=cohens_d, alpha=alpha, power=power)
            sample_sizes[i, j] = ceil(sample_size)  # round up to whole animals
    return power_values, effect_size_values, sample_sizes
def bootstrap_page():
    """Render the "Mouse Sample Size Forecaster" page.

    Given how many mice are wanted alive at a desired age, the page
    estimates how many mice to start with at a given purchase age, using
    Kaplan-Meier survival of the control animals in the module-level
    ``data`` frame.  After that estimate has been calculated, a bootstrap
    step can be run to show how much the estimate varies between simulated
    experiments.
    """
    # Initial filter for Control treatment
    control_data = data[data['treatment'] == 'Control']

    def filter_data_by_sex_and_cohort(data, selected_sex, selected_cohorts):
        """Return the rows of `data` for one sex, optionally limited to cohorts.

        `selected_sex` is "Male" or "Female"; `selected_cohorts` is the
        multiselect result, where the entry "All" disables cohort filtering.
        """
        # UI label -> code used in the 'sex' column
        sex_mapping = {
            "Male": "m",
            "Female": "f"
        }
        data = data[data['sex'] == sex_mapping[selected_sex]]
        if "All" not in selected_cohorts:
            data = data[data['cohort'].isin(selected_cohorts)]
        return data

    def format_median(kmf):
        """Return the fitted median survival as text for a plot legend.

        An infinite median (curve never drops to 50%) cannot be converted
        with int(), so it is reported as "not reached" instead of crashing.
        """
        median = kmf.median_survival_time_
        return f"{int(median)} days" if np.isfinite(median) else "not reached"

    def calculate_required_mice(n_mice, purchase_age, desired_age, filtered_data):
        """Estimate the starting number of mice from Kaplan-Meier survival.

        Returns `(required_mice, adjusted_surv_prob)`, where
        `adjusted_surv_prob` is the survival probability at `desired_age`
        divided by the one at `purchase_age`, and `required_mice` is
        `n_mice / adjusted_surv_prob` rounded up.

        Raises ValueError when there is no data to fit, or when either
        survival probability is zero (the estimate would be unbounded or
        undefined).
        """
        if filtered_data.empty:
            raise ValueError("no control mice match the selected sex and cohort(s).")

        kmf = KaplanMeierFitter()
        kmf.fit(filtered_data['age(days)'], event_observed=filtered_data['dead'],
                timeline=np.arange(0, filtered_data['age(days)'].max() + 1, 1))

        surv_prob_at_purchase = kmf.predict(purchase_age)
        surv_prob_at_desired_age = kmf.predict(desired_age)
        if not (surv_prob_at_purchase > 0 and surv_prob_at_desired_age > 0):
            raise ValueError(
                "the survival data contain no mice surviving to the requested "
                "age(s), so the required number of mice cannot be estimated."
            )

        adjusted_surv_prob = surv_prob_at_desired_age / surv_prob_at_purchase
        required_mice = ceil(n_mice / adjusted_surv_prob)
        return required_mice, adjusted_surv_prob

    def plot_survival_curve(ax, data, sex):
        """Plot the Kaplan-Meier curve of the uncensored rows of `data` on `ax`."""
        kmf = KaplanMeierFitter()
        # Censored animals are deliberately excluded from this plot.
        # `== True` (not plain truthiness) keeps this working whether the
        # 'dead' column holds booleans or 0/1 integers.
        uncensored_data = data[data['dead'] == True]  # noqa: E712
        kmf.fit(uncensored_data['age(days)'], event_observed=uncensored_data['dead'])
        kmf.plot(ax=ax, label=f'{sex} (n={len(uncensored_data)}) Median: {format_median(kmf)}')

    def bootstrap_estimate(n_mice, required_mice, purchase_age, desired_age, data, n_iterations=10):
        """Simulate experiments by resampling mice alive at `purchase_age`.

        Each iteration draws `required_mice` animals with replacement, counts
        how many reach `desired_age`, and converts that survival fraction
        into the number of mice that would have been needed to end up with
        `n_mice`.  An iteration in which no sampled mouse reaches
        `desired_age` is recorded as `np.inf`.

        Returns `(estimates, best_dataset, worst_dataset)`: a Series with one
        estimate per iteration, and the resampled frames with the highest
        and lowest Kaplan-Meier median survival.

        Raises ValueError when no mouse in `data` is alive at `purchase_age`.
        """
        # The eligibility filter does not change between iterations.
        eligible = data[data['age(days)'] >= purchase_age]
        if eligible.empty:
            raise ValueError("no mice in the selected data are alive at the starting age.")
        sample_size = int(required_mice)

        required_mice_samples = []
        median_survivals = []
        datasets = []
        for _ in range(int(n_iterations)):
            sample_data = eligible.sample(n=sample_size, replace=True)
            survivors = int((sample_data['age(days)'] >= desired_age).sum())
            # n_mice / (survivors / required_mice), written with a single
            # division to avoid double rounding; zero survivors means the
            # estimate is unbounded.
            required_mice_samples.append(
                n_mice * required_mice / survivors if survivors else np.inf
            )
            kmf = KaplanMeierFitter()
            kmf.fit(sample_data['age(days)'], event_observed=sample_data['dead'])
            median_survivals.append(kmf.median_survival_time_)
            datasets.append(sample_data)

        best_dataset = datasets[np.argmax(median_survivals)]
        worst_dataset = datasets[np.argmin(median_survivals)]
        return pd.Series(required_mice_samples), best_dataset, worst_dataset

    def main():
        """Draw the page widgets and run the estimate / bootstrap on demand."""
        sex_options = ("Male", "Female")

        st.title("Mouse Sample Size Forecaster")
        st.markdown("""Given a specific number of mice needed at a desired age, this app calculates the number of mice you should initially start with. This prediction is based on survival data from control mice (HET3s) used by the Interventions Testing Program (ITP) from cohorts 2004-2017. You may optionally select an individual cohort or any combination of cohorts instead of all cohorts. Note that the 2017 has unusually low survival, so it may be unwise (or perhaps wise if one wants to be conservative) to use that cohort.""")

        available_cohorts = ["All"] + sorted(control_data['cohort'].unique().tolist())
        selected_cohorts = st.multiselect("Select Cohort(s):", available_cohorts, default=["All"])

        # Lower bounds keep nonsensical values (negative ages, zero mice)
        # out of the calculations below.
        purchase_age = st.number_input("Enter Starting Age of Mice (in days):", min_value=0, value=100)
        desired_age = st.number_input("Enter Desired Age of Survival for Mice (in days):", min_value=0, value=913)
        n_mice = st.number_input("Enter Number of Mice You Want at Desired Age:", min_value=1, value=8)

        # st.button is only True on the run right after the click; the
        # session flag keeps the results visible on later reruns.
        if "calculate_pressed" not in st.session_state:
            st.session_state.calculate_pressed = False
        if not (st.button("Calculate") or st.session_state.calculate_pressed):
            return

        if desired_age < purchase_age:
            st.error("The desired age must not be lower than the starting age.")
            return

        # ---- Initial estimate ------------------------------------------
        results = []
        required_mice_dict = {}  # required starting mice per sex, reused by the bootstrap
        fig_survival, ax_survival = plt.subplots(figsize=(10, 7))
        for sex_option in sex_options:
            filtered_data = filter_data_by_sex_and_cohort(control_data, sex_option, selected_cohorts)
            try:
                required_mice, _ = calculate_required_mice(n_mice, purchase_age, desired_age, filtered_data)
            except ValueError as err:
                # e.g. every cohort deselected, or an age nobody survived to
                plt.close(fig_survival)
                st.error(f"{sex_option}: {err}")
                return
            required_mice_dict[sex_option] = required_mice
            results.append([sex_option, required_mice])
            plot_survival_curve(ax_survival, filtered_data, sex_option)

        ax_survival.set_title('Survival Data Used to Calculate Number of Required Mice')
        ax_survival.set_ylabel('Survival Probability')
        ax_survival.set_xlabel('Days')
        ax_survival.legend()
        st.pyplot(fig_survival)
        # pyplot keeps figures alive until closed; avoid leaking one per rerun.
        plt.close(fig_survival)

        result_df = pd.DataFrame(results, columns=["Sex", "Required Mice"])
        result_df.set_index("Sex", inplace=True)
        st.table(result_df)
        st.session_state.calculate_pressed = True

        # ---- Bootstrap -------------------------------------------------
        # This point is only reached once an estimate exists.  Testing for
        # the mere presence of the session key would always succeed and let
        # "Bootstrap" run before required_mice_dict was filled.
        bootstrap_cycles = st.number_input("Number of Bootstrap Cycles:", min_value=1, value=10)
        if not st.button("Bootstrap"):
            return

        st.markdown("""
Bootstrapping is a technique used to simulate the variability one might expect in an actual experiment by drawing repeated random samples from a dataset. In this case, we are using the survival data, as visualized in the Kaplan-Meier curve above, to understand potential variation in our experiment outcomes.
For our bootstrap analysis, we take a random subsample from the selected data, with a size equivalent to our initial estimate of required mice (as calculated above). This acts as a simulated experiment, where we start with the estimated number of mice and observe how many reach the desired age. By doing this thousands of times, we simulate the experiment under many different scenarios to understand the potential range of outcomes.
""")

        bootstrap_results = []
        fig_hist, axes_hist = plt.subplots(nrows=1, ncols=2, figsize=(14, 6))
        fig_km, axes_km = plt.subplots(nrows=2, ncols=1, figsize=(10, 14))

        for idx, sex_option in enumerate(sex_options):
            filtered_data = filter_data_by_sex_and_cohort(control_data, sex_option, selected_cohorts)
            bootstrap_series, best_data, worst_data = bootstrap_estimate(
                n_mice, required_mice_dict[sex_option], purchase_age, desired_age,
                filtered_data, n_iterations=bootstrap_cycles)

            # Iterations without survivors give an unbounded estimate, which
            # neither ceil() nor hist() accepts; summarise the finite ones
            # and tell the user how many were dropped.
            finite_estimates = bootstrap_series[np.isfinite(bootstrap_series)]
            unbounded_count = len(bootstrap_series) - len(finite_estimates)
            if unbounded_count:
                st.warning(
                    f"{sex_option}: in {unbounded_count} of {len(bootstrap_series)} bootstrap "
                    "samples no mouse reached the desired age; these samples are excluded "
                    "from the table and histogram."
                )

            hist_ax = axes_hist[idx]
            if not finite_estimates.empty:
                bootstrap_results.append([
                    sex_option,
                    ceil(finite_estimates.median()),
                    ceil(finite_estimates.quantile(0.80)),
                    ceil(finite_estimates.quantile(0.90)),
                    ceil(finite_estimates.quantile(0.95)),
                    ceil(finite_estimates.quantile(0.99))
                ])
                # A fixed bin count is used on purpose: with one bin per
                # integer, the discrete survival fractions leave gaps.
                hist_ax.hist(finite_estimates, bins=20, color='skyblue', edgecolor='black')
                hist_ax.axvline(finite_estimates.median(), color='red', linestyle='dashed', linewidth=1)
            hist_ax.set_title(f'{sex_option} Required Mice Distribution')
            hist_ax.set_xlabel('Required Mice')
            hist_ax.set_ylabel('Frequency')

            # Kaplan-Meier curves of the luckiest and unluckiest resample.
            km_ax = axes_km[idx]
            kmf = KaplanMeierFitter()
            kmf.fit(best_data['age(days)'], event_observed=best_data['dead'])
            kmf.plot(ax=km_ax, label=f'Best Median Survival (n={len(best_data)}): {format_median(kmf)}')
            kmf = KaplanMeierFitter()
            kmf.fit(worst_data['age(days)'], event_observed=worst_data['dead'])
            kmf.plot(ax=km_ax, label=f'Worst Median Survival (n={len(worst_data)}): {format_median(kmf)}')
            km_ax.set_title(f'{sex_option} Kaplan-Meier Curves for Best and Worst Median Survival from Bootstrapping')
            km_ax.set_xlim([purchase_age, desired_age])  # show only the window of interest
            km_ax.set_ylabel('Survival Probability')
            km_ax.set_xlabel('Days')
            km_ax.legend()

        bootstrap_df = pd.DataFrame(bootstrap_results, columns=["Sex", "Median", "80th", "90th", "95th", "99th"])
        bootstrap_df.set_index("Sex", inplace=True)
        st.table(bootstrap_df)

        st.pyplot(fig_hist)
        plt.close(fig_hist)
        st.pyplot(fig_km)
        plt.close(fig_km)

    # NOTE(review): the page body only runs when this module executes as
    # __main__ -- presumably always the case under `streamlit run`; confirm
    # before importing this module from elsewhere.
    if __name__ == "__main__":
        main()
# Sidebar navigation: each page label maps to the function that renders it.
page_renderers = {
    "Power Analysis & Heatmaps": power_analysis_page,
    "Mouse Cost Calculator": cost_estimation_page,
    "Bootstrap Survival Estimation": bootstrap_page,
}
page = st.sidebar.radio("Choose a page:", list(page_renderers))

# Render the chosen page; a value that matches no label renders nothing.
render_selected_page = page_renderers.get(page)
if render_selected_page is not None:
    render_selected_page()