-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path: dataproccessor.py
674 lines (499 loc) · 24.1 KB
/
dataproccessor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
import math
import os
import urllib
import urllib.request
from datetime import datetime

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import seaborn as sns

import coronatracker as ct
def make_plots():
    """Generate every plot used in the bot's tweets."""
    print('Attempting to build plots!')
    ct.logger.info('Making plots...')

    make_summary_bar_plot()

    ts_data = ct.get_time_series()
    world_frame = get_global_time_series()

    # Cumulative time-series plot
    make_time_series_plot(get_country_cumulative(ts_data))

    # Day-over-day change plot
    make_daily_change_plot(get_total_daily_change(ts_data))

    # Cross-country cumulative-case comparison plot
    leaders = find_metric_leader(world_frame, inc_US=True, size=8)
    make_comparison_plot(leaders, get_country_cumulative(world_frame, countries=leaders))

    make_state_death_plot()
    make_per_capita_plot()
    make_testing_plot()
    make_vent_icu_plot()
    ct.logger.info('Created plots!')
def make_summary_bar_plot():
    """
    Build and save a bar chart of cases, deaths, and recoveries for the top
    25 U.S states by caseload.

    The figure is written to ct.plot_path + 'state_sum.png'.
    """
    us_frame = ct.get_jhu_data()
    us_frame.sort_values(by='cases', axis=0, ascending=False, inplace=True)
    plt.suptitle('COVID-19 Details for the United States')
    # Plot setup for JHU figure
    state_frame = ct.make_state_frame(us_frame)
    state_frame = state_frame.iloc[np.arange(0, 26), [0, 1, 2, 3]]
    pos = np.arange(len(state_frame['state']))
    width = 0.9
    plt.gcf().set_size_inches(14, 14)
    case_bar = plt.bar(pos, state_frame['cases'], width, label='Cases')
    death_bar = plt.bar(pos, state_frame['deaths'], width, label='Deaths')
    # Recoveries are stacked on top of deaths.
    recov_bar = plt.bar(pos, state_frame['recoveries'], width,
                        bottom=state_frame['deaths'], label='Recoveries')
    plt.xlabel('State')
    plt.ylabel('Count')
    # Plain string: the original used an f-string with no placeholders.
    plt.title('Confirmed COVID-19 Case Statistics for the Top 25 States by Caseload')
    # Rotate the tick labels here. The original rotated ticks before any data
    # was plotted, so the rotation was applied to an empty implicit axis and
    # then discarded when the labels were set afterwards.
    plt.xticks(pos, state_frame['state'].tolist(), fontsize=10, rotation=45)
    plt.legend((case_bar[0], death_bar[0], recov_bar[0]), ('Cases', 'Deaths', 'Recoveries'), loc='upper right')
    plt.savefig(ct.plot_path + 'state_sum.png')
    plt.show(block=False)
    plt.close()
def make_time_series_plot(freqs: list):
    """
    Creates side-by-side line plots (standard and natural-log scale) of the
    cumulative case count per day, saved to ct.plot_path + 'rate_plot.png'.

    :param freqs: The cumulative number of cases for each day of the outbreak
    """
    days = len(freqs)
    fig, (reg_ax, log_ax) = plt.subplots(1, 2)
    fig.set_size_inches(12, 10)
    fig.suptitle('Cumulative Cases per Day in the United States (Standard Scale and Natural Log)')
    reg_ax.plot(np.arange(start=1, stop=days + 1), freqs, color='red')
    reg_ax.set_xlabel('Days Since 01/21/2020')
    reg_ax.set_ylabel('Number of cases')
    # Guard against zero counts: math.log(0) raises ValueError and early days
    # in the series can legitimately be zero. A zero day maps to 0 (= log 1).
    log_freqs = [math.log(freq) if freq > 0 else 0 for freq in freqs]
    log_ax.plot(np.arange(start=1, stop=days + 1), log_freqs, color='red')
    log_ax.set_xlabel('Days Since 01/21/2020')
    log_ax.set_ylabel('Number of cases (Natural Log Scale)')
    plt.savefig(ct.plot_path + 'rate_plot.png')
    plt.show(block=False)
    plt.close()
def make_daily_change_plot(changes: list):
    """
    Plots the daily change in cases, saved to ct.plot_path + 'change_plot.png'.

    :param changes: A list of the changes that have occurred each day
    """
    # The annotation is now `list`; the original annotated the parameter with
    # a literal empty list (`changes: []`), which is not a type.
    # Plus 1 because numpy's arange stops one short of the stop value.
    days = len(changes) + 1
    plt.plot(np.arange(1, stop=days), changes, color='red')
    plt.title('Daily Change in Cases Since 01/22/2020')
    plt.gcf().set_size_inches(12, 12)
    plt.xlabel('Days since 01/22/2020')
    plt.ylabel('Change in Cases from Previous Day')
    plt.savefig(ct.plot_path + 'change_plot.png')
    plt.show(block=False)
    plt.close()
def make_comparison_plot(countries: list, cumulative_cases: list):
    """
    Creates a line plot comparing cumulative case counts across countries,
    saved to ct.plot_path + 'comp_plot.png'.

    :param countries: The names of the countries to plot. Must include 'US'
        (the U.S series is used to determine the day range).
    :param cumulative_cases: The cumulative number of cases for each country,
        in the same order as `countries`
    """
    change_dict = {country: change for country, change in zip(countries, cumulative_cases)}
    # Repeat the day index once per plotted country. The original hard-coded
    # the repeat count to 8, which breaks for any other number of countries.
    days = np.arange(start=0, stop=len(change_dict['US'])).tolist() * len(change_dict)
    change_frame = pd.DataFrame(change_dict).melt(var_name='country', value_name='cases')
    change_frame['day'] = days
    sns.set()
    sns.lineplot(x='day', y='cases', hue='country', data=change_frame)
    plt.title('Cumulative Cases in the Top 10 Countries by Cases Globally')
    plt.gcf().set_size_inches(12, 12)
    plt.xlabel('Days since 01/22/2020')
    plt.ylabel('Cumulative Cases')
    plt.savefig(ct.plot_path + 'comp_plot.png')
    plt.show(block=False)
    plt.close()
def get_country_cumulative(data: pd.DataFrame, countries='US') -> list:
    """
    Calculates the cumulative number of cases per day for the U.S, a single
    country, or a list of countries.

    :param data: A DataFrame containing time series data for cases. Date
        columns are recognized by containing the substring '/20'.
    :param countries: 'US' (sums every row in `data`), a single country name,
        or a list of country names.
    :return: A list of daily totals, or a list of such lists when a list of
        countries is given.
    """

    def _daily_totals(frame: pd.DataFrame) -> list:
        # One total per date column; non-date columns are ignored.
        return [np.sum(frame[column].tolist()) for column in frame.columns if '/20' in column]

    if countries == 'US':
        # NOTE(review): this branch sums every row of `data`, not just U.S
        # rows — callers pass an already U.S-filtered frame here.
        return _daily_totals(data)
    elif isinstance(countries, list):
        return [_daily_totals(data[data['Country_Region'] == country])
                for country in countries]
    else:
        return _daily_totals(data[data['Country_Region'] == countries])
def get_total_daily_change(data: pd.DataFrame, country='US') -> list:
    """
    Calculates the change in cases (or deaths) for each day recorded in the
    data, by feeding ct.get_daily_change a growing prefix of date columns.

    :param data: a DataFrame containing time series data; date columns are
        recognized by containing the substring '/20'.
    :param country: The country to get the changes for. Default is US.
    :return: The change in cases for each day
    """
    if country == 'US':
        columns = [column for column in data.columns.to_list() if '/20' in column]
        # Prefixes columns[:1] .. columns[:len]. The original loop's
        # `index == len(columns)` else-branch was unreachable (range never
        # yields len(columns)), and `changes` was initialized twice.
        return [ct.get_daily_change(data[columns[:end]])
                for end in range(1, len(columns) + 1)]
    else:
        country_data = data[data['Country_Region'] == country]
        columns = [column for column in country_data.columns.to_list() if '/20' in column]
        # Mirrors the original's windows: start at the third date column and
        # end anywhere from the fourth column to the last.
        return [ct.get_daily_change(country_data[columns[2:end]])
                for end in range(4, len(columns) + 1)]
def get_global_time_series() -> pd.DataFrame:
    """
    Downloads and saves the global confirmed-case time series without
    filtering for only U.S data.

    If a cached copy exists and its newest date column is today's date, the
    cached file is returned instead of re-downloading.

    :return: A DataFrame with 'Province_State'/'Country_Region' columns plus
        one column per date.
    """
    cache_path = ct.jhu_path + 'jhu_global_time.csv'
    if os.path.exists(cache_path):
        data = pd.read_csv(cache_path)
        newest_date = datetime.strptime(data.columns.to_list()[-1], '%m/%d/%y')
        # Compare date to date: the original compared a datetime to a date,
        # which is never equal, so the cache was always ignored.
        if newest_date.date() == datetime.now().date():
            ct.logger.info('Currently downloaded global time series is up to date. Reading file...')
            return data
    ct.logger.info('Global time series data may be out of date! Trying to download new file...')
    return _download_global_time_series(cache_path)


def _download_global_time_series(cache_path: str) -> pd.DataFrame:
    """Download the JHU global confirmed time series, normalize its column
    names, cache it at cache_path, and return it as a DataFrame."""
    file_link = ('https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/'
                 'csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv')
    temp_path = ct.jhu_path + 'jhu_global_time_temp.csv'
    newest_csv_req = urllib.request.Request(file_link)
    csv_file = urllib.request.urlopen(newest_csv_req)
    csv_data = csv_file.read()
    with open(temp_path, 'wb') as file:
        file.write(csv_data)
    global_frame = pd.read_csv(temp_path)
    global_frame.drop(columns=['Lat', 'Long'], inplace=True)
    # Normalize JHU's '/'-style names. The original only renamed on one of
    # its two (duplicated) download paths, leaving the other inconsistent
    # with the rest of this module, which expects 'Country_Region'.
    global_frame.rename(columns={'Province/State': 'Province_State', 'Country/Region': 'Country_Region'},
                        inplace=True)
    global_frame.to_csv(cache_path)
    os.remove(temp_path)
    ct.logger.info('Succesfully downloaded time series data!')
    return global_frame
def get_death_time_series(country='all') -> pd.DataFrame:
    """
    Downloads the deaths time series. Can return data for a specific country
    or all countries.

    :param country: Either 'all' for all countries, 'US' for the U.S
        county-level file, or a specific country's name. Default is 'all'.
    :return: A DataFrame containing the death time series for the given
        country or all countries.
    """
    if country == 'US':
        file_path = ct.jhu_path + 'jhu_death_time_us.csv'
        file_link = ('https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/'
                     'csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_deaths_US.csv')
        drop_list = ['UID', 'iso2', 'iso3', 'code3', 'FIPS', 'Lat', 'Long_', 'Combined_Key']
    else:
        file_path = ct.jhu_path + 'jhu_death_time_global.csv'
        file_link = ('https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/'
                     'csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_deaths_global.csv')
        drop_list = ['Lat', 'Long']
    if os.path.exists(file_path):
        data = pd.read_csv(file_path)
        newest_date = datetime.strptime(data.columns.to_list()[-1], '%m/%d/%y')
        # Compare date to date: the original compared a datetime to a date,
        # which is never equal, so the cache was always ignored.
        if newest_date.date() == datetime.now().date():
            ct.logger.info('Currently downloaded global death time series is up to date. Reading file...')
            return data
    ct.logger.info('Global death time series data may be out of date! Trying to download new file...')
    death_frame = _download_death_time_series(file_path, file_link, drop_list)
    # JHU has some inconsistencies in column naming, so here we patch them up.
    # The original performed this rename on only one of its two (duplicated)
    # download paths, so a country filter on the other path would have failed.
    if country == 'US':
        death_frame.rename(columns={'Admin2': 'City_County'}, inplace=True)
        return death_frame
    death_frame.rename(columns={'Province/State': 'Province_State',
                                'Country/Region': 'Country_Region'}, inplace=True)
    if country == 'all':
        return death_frame
    return death_frame[death_frame['Country_Region'] == country]


def _download_death_time_series(file_path: str, file_link: str, drop_list: list) -> pd.DataFrame:
    """Download a JHU deaths CSV from file_link, drop drop_list columns,
    cache the result at file_path, and return it as a DataFrame."""
    newest_csv_req = urllib.request.Request(file_link)
    csv_file = urllib.request.urlopen(newest_csv_req)
    csv_data = csv_file.read()
    temp_path = ct.jhu_path + 'jhu_time_temp.csv'
    with open(temp_path, 'wb') as file:
        file.write(csv_data)
    death_frame = pd.read_csv(temp_path)
    death_frame.drop(columns=drop_list, inplace=True)
    death_frame.to_csv(file_path)
    os.remove(temp_path)
    ct.logger.info('Succesfully downloaded death time series data!')
    return death_frame
def find_metric_leader(global_data: pd.DataFrame, inc_US=False, size=1):
    """
    Finds the country/countries with the largest value in the most recent
    (last) column of the given time series data.

    :param global_data: Global time series data; the last column is assumed
        to be the newest date.
    :param inc_US: Should the U.S be included as a potential leader? Default
        is False due to the U.S' current position.
    :param size: How many leaders should be returned. Default is 1.
    :return: The leading country's name if size is 1 (or None if no eligible
        country exists), otherwise a list of up to `size` names.
    """
    newest_column = global_data.columns.to_list()[-1]
    # Sort a copy so the caller's frame is not reordered as a side effect
    # (the original sorted in place, mutating the argument).
    ranked = global_data.sort_values(by=newest_column, ascending=False)
    leader_list = []
    for country in ranked['Country_Region']:
        if country == 'US' and not inc_US:
            continue
        if size == 1:
            return country
        if len(leader_list) < size:
            leader_list.append(country)
        else:
            break
    if size == 1:
        # No eligible country found (matches the original's fall-through).
        return None
    # The original returned None when the data ran out before an extra
    # country triggered the return; return the (possibly partial) list.
    return leader_list
def get_time_to_target(target=-1, metric_type='cases'):
    """
    Estimates the number of days until the U.S reaches a target number of
    cases, or until U.S deaths overtake the current global leader, based on
    the mean daily change over the past three days.

    :param metric_type: The target type. Either 'cases' or 'deaths'.
    :param target: The target number of cases (required when metric_type is
        'cases').
    :return: The estimated number of days rounded to 2 places, or -1 if the
        target can never be reached at the current rates. Returns None for an
        unrecognized metric_type (matching the original's fall-through).
    :raises ValueError: If metric_type is 'cases' and no target was given.
    """
    if metric_type == 'cases':
        if target == -1:
            raise ValueError('Must specify a target if type is set to cases!')
        # NOTE: time-to-leader mode for cases is disabled since the U.S is
        # now the global case leader.
        # Slope is calculated from the past three days to prevent skew from
        # earlier time periods.
        us_changes = get_total_daily_change(ct.get_time_series())[-3:]
        us_slope = np.mean(us_changes)
        est_us_cases = get_country_cumulative(ct.get_time_series())[-1]
        if us_slope <= 0:
            # Cases flat or falling: the target is never reached, and the
            # original loop would never have terminated.
            return -1
        add_days = 0
        while est_us_cases <= target:
            add_days += 0.1
            est_us_cases += us_slope * 0.1
        return round(add_days, 2)
    if metric_type == 'deaths':
        # Slope is calculated from the past three days to prevent skew from
        # earlier time periods.
        us_changes = get_total_daily_change(get_death_time_series('US'))[-3:]
        us_slope = np.mean(us_changes)
        leader = find_metric_leader(get_death_time_series())
        leader_changes = get_total_daily_change(get_death_time_series(country=leader), leader)[-3:]
        leader_slope = np.mean(leader_changes)
        est_us_deaths = get_country_cumulative(get_death_time_series('US'))[-1]
        est_leader_deaths = get_country_cumulative(get_death_time_series(leader), countries=leader)[-1]
        # >= rather than the original's >: with exactly equal slopes the gap
        # never closes and the original loop would never have terminated.
        if leader_slope >= us_slope:
            return -1
        add_days = 0
        while est_us_deaths <= est_leader_deaths:
            add_days += 0.1
            est_us_deaths += us_slope * 0.1
            est_leader_deaths += leader_slope * 0.1
        return round(add_days, 2)
def get_top_states_by_metric(metric: str, size: int) -> list:
    """
    Gets the top `size` states by a certain metric, either cases or deaths.

    :param metric: Either 'cases' or 'deaths'.
    :param size: The number of entries to return.
    :return: A list of at most `size` unique state names, ordered by the most
        recent day's value, descending.
    :raises ValueError: If metric is neither 'cases' nor 'deaths'.
    """
    if metric == 'cases':
        data = ct.get_time_series()
    elif metric == 'deaths':
        data = get_death_time_series('US')
    else:
        raise ValueError("'metric' must be either 'cases' or 'deaths'!")
    most_recent_column = data.columns.to_list()[-1]
    data.sort_values(ascending=False, by=most_recent_column, inplace=True)
    states = []
    # Collect unique states in sorted order until `size` are found. The
    # original's extra `len(states) <= size` guard was dead code: the break
    # below always fires the moment the list reaches `size`.
    for state in data['Province_State']:
        if state not in states:
            states.append(state)
            if len(states) == size:
                break
    return states
def make_state_death_plot():
    """
    Creates a line plot of cumulative death totals for the top 10 U.S states
    by deaths, saved to ct.plot_path + 'death_comp_plot.png'.
    """
    death_data = get_death_time_series('US')
    states = get_top_states_by_metric('deaths', 10)
    state_counts = []
    relevant_columns = [column for column in death_data.columns.to_list() if '/20' in column]
    new_states = []
    for state in states:
        state_data = death_data[death_data['Province_State'] == state]
        state_data = state_data[relevant_columns]
        # One state label per day so rows line up with the per-day sums below.
        new_states += [state] * len(relevant_columns)
        for column in state_data.columns:
            state_counts.append(int(state_data[column].sum()))
    # Rearrange the data into long form for seaborn.
    dates = relevant_columns * len(states)
    # len(states) rather than the original's hard-coded 10, so the frame
    # stays consistent if the requested state count above ever changes.
    day_num = list(range(len(relevant_columns))) * len(states)
    plot_frame = pd.DataFrame({'date': dates, 'state': new_states, 'counts': state_counts, 'day_num': day_num})
    # Finally start setting up the plot
    sns.set()
    sns.lineplot(x='day_num', y='counts', hue='state', data=plot_frame)
    plt.title(f'Daily Death Totals for the Top 10 U.S States by Death Total On {ct.now.strftime("%m/%d/%y")}')
    plt.xlabel('Days Since 01/22/2020')
    plt.ylabel('Deaths')
    plt.gcf().set_size_inches(12, 12)
    plt.savefig(ct.plot_path + 'death_comp_plot.png')
    plt.show(block=False)
    plt.close()
def get_deaths_per_capita(multiplier=100000, size=5) -> list:
    """
    Gets the top `size` U.S states by deaths per `multiplier` population.

    :param multiplier: The multiplier to use for the resulting rate, as in
        'X deaths per {multiplier}'. Default is 100,000.
    :param size: The number of states to return. Default is five.
    :return: A list of (state, rate) tuples sorted by rate, descending.
    """
    death_data = get_death_time_series('US')
    most_recent_column = death_data.columns.to_list()[-1]
    combinations = []
    # Series.unique() preserves first-occurrence order, matching the
    # original's manual de-duplication loop.
    for state in death_data['Province_State'].unique():
        state_frame = death_data[death_data['Province_State'] == state]
        total_pop = int(state_frame['Population'].sum())
        # Skip entries with no recorded population to avoid dividing by zero.
        if total_pop == 0:
            continue
        total_deaths = int(state_frame[most_recent_column].sum())
        # Deaths per `multiplier` residents, rounded to 2 places.
        death_rate = round((total_deaths / total_pop) * multiplier, 2)
        combinations.append((state, death_rate))
    top_list = sorted(combinations, key=lambda state_pair: state_pair[1], reverse=True)
    return top_list[:size]
def get_top_increasing_by_metric(metric='deaths', size=5) -> list:
    """
    Finds the top {size} states by {metric}

    NOTE(review): this function appears unfinished — it builds the filtered
    frame and state list but only prints the states and implicitly returns
    None, despite the annotated list return type. The `metric='cases'` path
    is not implemented at all. Confirm intent before relying on it.

    :param metric: The metric to be measured. Either cases or deaths
    :param size: How many states should be returned. Default is 5
    :return: A list of the top {size} states by {metric}
    """
    if metric == 'deaths':
        death_frame = get_death_time_series('US')
        # Keep only the date columns plus the state identifier column.
        req_columns = [column for column in death_frame.columns.to_list() if '/20' in column
                       or column == 'Province_State']
        # Cruise-ship entries ('... Princess') are excluded from the states.
        states = [state for state in death_frame['Province_State'].unique() if 'Princess' not in state]
        death_frame = death_frame[req_columns]
        print(states)
def make_per_capita_plot():
    """
    Plots a bar chart of the top five states by deaths per 100,000 population
    and saves it to ct.plot_path + 'capita_plot.png'.
    """
    leaders = get_deaths_per_capita()
    # Split the (state, rate) pairs into parallel lists for plotting.
    state_names = [pair[0] for pair in leaders]
    rates = [pair[1] for pair in leaders]
    plt.gcf().set_size_inches(10, 10)
    plt.title(f'Top 5 States by Deaths per 100,000 Population on {ct.now.strftime("%m/%d/%y")}')
    plt.xlabel('State')
    plt.ylabel('Deaths per 100,000 Population')
    plt.bar(x=state_names, height=rates)
    plt.savefig(ct.plot_path + 'capita_plot.png')
    plt.show(block=False)
    plt.close()
def make_testing_plot():
    """
    Plots a bar chart of the top five states by tests per 100,000 population,
    overlaid with each state's incidence, and saves it to
    ct.plot_path + 'capita_rate.png'.
    """
    ranked = ct.get_jhu_data().sort_values(by='test_rate', ascending=False)
    top_states = ranked['state'][:5]
    top_tests = ranked['test_rate'][:5]
    top_incidence = ranked['incidence'][:5]
    plt.title(f'Top 5 States by Testing Rate and Their Incidence On {ct.now.strftime("%m/%d/%y")}')
    plt.xlabel('State')
    plt.ylabel('Measure Per 100,000 Population')
    # Incidence bars are drawn over the (taller) testing bars.
    test_bar = plt.bar(x=top_states, height=top_tests, color='#ffcc00')
    inc_bar = plt.bar(x=top_states, height=top_incidence, color='red')
    plt.gcf().set_size_inches(10, 10)
    plt.legend((test_bar[0], inc_bar[0]), ('Tests', 'Incidence'), loc='best')
    plt.savefig(ct.plot_path + 'capita_rate.png')
    plt.show(block=False)
    plt.close()
def get_tracking_project_data() -> pd.DataFrame:
    """
    Fetches historical U.S data from the COVID-19 Tracking Project
    (https://covidtracking.com/) and caches it on disk.

    :return: A DataFrame containing the historical data from the COVID-19
        Tracking Project
    """
    # `not os.path.exists(...)` replaces the original's `... is not True`,
    # an identity comparison against a bool.
    if not os.path.exists(ct.tracking_proj_path):
        os.mkdir(ct.tracking_proj_path)
    newest_csv_req = urllib.request.Request("https://covidtracking.com/api/v1/us/daily.csv")
    csv_file = urllib.request.urlopen(newest_csv_req)
    csv_data = csv_file.read()
    with open(ct.tracking_proj_path + 'historical_data.csv', 'wb') as file:
        file.write(csv_data)
    tracking_frame = pd.read_csv(ct.tracking_proj_path + 'historical_data.csv')
    return tracking_frame
def make_vent_icu_plot():
    """
    Generates a line plot of current ICU and ventilator usage over time for
    the U.S and saves it to ct.plot_path + 'vent_icu_plot.png'.

    (The original docstring incorrectly described this as a test positivity
    rate plot; the code plots 'inIcuCurrently' and 'onVentilatorCurrently'.)
    """
    data = get_tracking_project_data()
    # Keep data on or after 03/26/2020, the first date on which both ICU and
    # ventilator counts are available.
    data = data[data['date'] >= 20200326]
    counts = data['inIcuCurrently'].tolist() + data['onVentilatorCurrently'].tolist()
    count_type = (['ICU'] * len(data['inIcuCurrently'])) + (['Ventilator'] * len(data['onVentilatorCurrently']))
    # Convert to a list before repeating: multiplying the ndarray itself
    # would scale the elements by 2 instead of repeating the sequence.
    days = np.arange(0, len(data['date'])).tolist() * 2
    plot_frame = pd.DataFrame({'day': days,
                               'counts': counts,
                               'count_type': count_type})
    sns.set()
    sns.lineplot(x='day', y='counts', hue='count_type', data=plot_frame)
    plt.title('Ventilator and ICU Usage for the U.S Since 03/26/2020')
    plt.xlabel('Days Since 03/26/2020')
    plt.ylabel('Count')
    plt.gcf().set_size_inches(10, 10)
    plt.savefig(ct.plot_path + 'vent_icu_plot.png')
    plt.show(block=False)
    plt.close()