-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdash_predictor.py
84 lines (67 loc) · 2.43 KB
/
dash_predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# -*- coding: utf-8 -*-
# Run this script, then visit http://127.0.0.1:8050/ in your web browser.
import dash
import dash_core_components as dcc
import dash_html_components as html
import pandas as pd
from fbprophet import Prophet
import numpy as np
from fbprophet.plot import plot_plotly
# Stylesheet(s) handed to the Dash application below.
external_stylesheets = [
    'https://codepen.io/chriddyp/pen/bWLwgP.css',
]
app = dash.Dash(
    __name__,
    external_stylesheets=external_stylesheets,
)
# Series to forecast: the country, an optional region within it, the data
# column to model, and how many periods past the history to predict.
country = 'United States'
region = 'Washington'          # a falsy value selects the national totals
prediction = 'ConfirmedCases'  # column that is forecast
prediction_length = 30         # number of future periods to forecast

# Oxford COVID-19 policy tracker data (updated daily).
DATA_URL = (
    'https://raw.githubusercontent.com/OxCGRT/covid-policy-tracker/'
    'master/data/OxCGRT_latest.csv'
)
# Load only the columns needed for the forecast, skipping malformed rows.
# `error_bad_lines` was deprecated in pandas 1.3 (replaced by `on_bad_lines`)
# and removed in pandas 2.0, so pass whichever keyword the installed pandas
# understands. 'warn' matches the old default of warning about each skipped
# line.
_pandas_version = tuple(int(part) for part in pd.__version__.split('.')[:2])
if _pandas_version >= (1, 3):
    _bad_line_handling = {'on_bad_lines': 'warn'}
else:
    _bad_line_handling = {'error_bad_lines': False}
df = pd.read_csv(
    DATA_URL,
    parse_dates=['Date'],
    encoding='ISO-8859-1',
    usecols=['Date', 'CountryName', 'RegionName', prediction],
    dtype={'RegionName': str, 'CountryName': str},
    **_bad_line_handling,
)
# Keep only the requested series: one region of `country`, or the country's
# national totals when `region` is empty.
_in_country = df['CountryName'] == country
if region:
    _selected = _in_country & (df['RegionName'] == region)
else:
    # Filter on a missing RegionName: a `Jurisdiction` column is not among
    # the columns loaded by read_csv, so filtering on it raised KeyError.
    # NOTE(review): assumes national-total rows have an empty RegionName in
    # the OxCGRT CSV -- confirm against the dataset's codebook.
    _selected = _in_country & df['RegionName'].isna()
# Drop the most recent day exactly once, after filtering and ordering by
# date, so the dropped row always belongs to the selected series.
# Presumably the latest day can be incomplete -- confirm.
df = df[_selected].sort_values('Date').iloc[:-1]
# Reshape into the two-column ('ds', 'y') frame that Prophet expects.
df = df.drop(columns=['CountryName', 'RegionName'])
# Rename the configured `prediction` column (not a hard-coded
# 'ConfirmedCases'), so changing `prediction` still yields a 'y' column.
df = df.rename(columns={'Date': 'ds', prediction: 'y'})
# Carry the last known value forward over gaps; values still missing (those
# before the first reported value) become 0.
df = df.ffill().fillna(0)
# Prophet model with multiplicative, weekly-only seasonality (yearly and
# daily seasonality disabled).
m = Prophet(
    seasonality_mode='multiplicative',
    yearly_seasonality=False,
    daily_seasonality=False,
    weekly_seasonality=True,
)
# Holidays are only added for the United States; no other country is mapped
# to a holiday code here.
if country == 'United States':
    m.add_country_holidays(country_name='US')
# Fit on the history, then predict it plus `prediction_length` future periods.
m.fit(df)
future = m.make_future_dataframe(periods=prediction_length)
forecast = m.predict(future)
# Plot the observed history and the forecast (with uncertainty) via Prophet's
# Plotly helper.
fig = plot_plotly(
    m,
    forecast,
    changepoints=False,
    xlabel='Date',
    ylabel=prediction,
    uncertainty=True,
    plot_cap=True,
)
# Fall back to the country name when no region is selected (national
# totals), so the title does not end with an empty location.
fig.layout.title = {'text': f'True and Predicted {prediction} in {region or country}'}
fig.update_layout(showlegend=True)
# Page contents: a heading, a one-line description, and the forecast graph.
_heading = html.H1(children='COVID Predictor')
_description = html.Div(children='\nCOVID-19 Cumulative Case Prediction\n')
_graph = dcc.Graph(id='prediction_graph', figure=fig)
app.layout = html.Div(children=[_heading, _description, _graph])
if __name__ == "__main__":
    # Start the Dash server with debug mode enabled when run as a script.
    app.run_server(debug=True)