Skip to content

Commit

Permalink
add covid-19 dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderVNikitin committed Oct 31, 2023
1 parent bad59b2 commit 2d70e5b
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
11 changes: 11 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,14 @@ def test_download(mocker, caplog):
assert "Cannot download dataset" in str(excinfo.value)
finally:
os.remove(resource_path)


def test_get_covid_19():
    """Check the basic structure of the dataset returned by ``get_covid_19``.

    The loader returns the time series array, the graph tuple and the list
    of states; the assertions below pin how these three pieces relate.
    """
    data, graph, state_names = tsgm.utils.get_covid_19()

    # The state list has 51 entries and uses lower-case names.
    assert len(state_names) == 51
    assert "new york" in state_names
    assert "california" in state_names

    # graph[0] holds the nodes (one per state), graph[1] holds the edges.
    assert len(graph[0]) == len(state_names)  # nodes
    assert len(graph[1]) == 220  # edges

    # The array is laid out as (states, timestamps, features).
    assert data.shape[0] == len(state_names)
    assert data.ndim == 3
    assert data.shape[2] == 4
    assert data.shape[1] >= 150
1 change: 1 addition & 0 deletions tsgm/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from tsgm.utils.datasets import * # noqa
from tsgm.utils.utils import * # noqa
from tsgm.utils.mmd import * # noqa
from tsgm.utils.covid19_data_utils import * # noqa
45 changes: 41 additions & 4 deletions tsgm/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@

from tensorflow import keras

from tsgm.utils import covid19_data_utils
from tsgm.utils import file_utils


# Module-level logger, registered under the name 'utils'. Its own level is set
# to DEBUG, so this logger does not filter out any of the standard levels.
logger = logging.getLogger('utils')
logger.setLevel(logging.DEBUG)


def gen_sine_dataset(N, T, D, max_value=10):
def gen_sine_dataset(N: int, T: int, D: int, max_value: int = 10) -> np.ndarray:
result = []
for i in range(N):
result.append([])
Expand All @@ -35,7 +36,7 @@ def gen_sine_dataset(N, T, D, max_value=10):
return np.transpose(np.array(result), [0, 2, 1])


def gen_sine_const_switch_dataset(N, T, D, max_value=10, const=0, frequency_switch=0.1):
def gen_sine_const_switch_dataset(N: int, T: int, D: int, max_value: int = 10, const: int = 0, frequency_switch: float = 0.1) -> tuple:
result_X, result_y = [], []
cur_y = 0
scales = np.random.random(D) * max_value
Expand Down Expand Up @@ -185,7 +186,7 @@ def get_mauna_loa() -> tuple:
return X, y


def split_dataset_into_objects(X, y, step=10):
def split_dataset_into_objects(X, y, step=10) -> tuple:
assert X.shape[0] == y.shape[0]

Xs, ys = [], []
Expand Down Expand Up @@ -293,7 +294,7 @@ def get_physionet2012() -> tuple:
return train_X, train_y, test_X, test_y, val_X, val_y


def download_physionet2012():
def download_physionet2012() -> None:
"""
Downloads the Physionet 2012 dataset files from the Physionet website
and extracts them in local folder 'physionet2012'
Expand Down Expand Up @@ -359,3 +360,39 @@ def _get_physionet_y_dataframe(file_path: str) -> pd.DataFrame:
y.index.name = 'recordid'
y.reset_index(inplace=True)
return y


def get_covid_19() -> tuple:
    """
    Load the US-states Covid-19 dataset together with its graph information.

    The data come from The New York Times, based on reports from state and
    local health agencies [1], and were adapted to the graph setting in [2].
    The CSV file is requested through ``file_utils.download`` every time this
    function is called and is then parsed by
    ``covid19_data_utils.covid_dataset``.

    Returns
    -------
    tuple
        A 3-tuple ``(X, graph, states)``:

        * ``X`` -- time series array of shape
          (n_nodes x n_timestamps x n_features). The features of every
          timestamp are, in this order: the number of deaths, the number of
          cases, deaths normalized by the population, and cases normalized
          by the population.
        * ``graph`` -- the graph tuple (nodes, edges).
        * ``states`` -- the order of states along the first axis of ``X``.

    References
    ----------
    [1] The New York Times. (2021). Coronavirus (Covid-19) Data in the
        United States. https://github.com/nytimes/covid-19-data.
    [2] Alexander V. Nikitin, St John, Arno Solin, Samuel Kaski.
        Proceedings of The 25th International Conference on Artificial
        Intelligence and Statistics, PMLR 151:10640-10660, 2022.
    """
    source_url = "https://raw.githubusercontent.com/nytimes/covid-19-data/master/us-states.csv"
    target_dir = "covid19"
    file_utils.download(source_url, target_dir)

    csv_path = os.path.join(target_dir, "us-states.csv")
    result, graph = covid19_data_utils.covid_dataset(csv_path)

    states = covid19_data_utils.LIST_OF_STATES
    # Order of the entries along the last (feature) axis of the output.
    feature_keys = ("deaths", "cases", "deaths_normalized", "cases_normalized")

    # Built as [timestamp][state][feature]; transposed below to put states first.
    per_timestamp = [
        [
            [record[key] for key in feature_keys]
            for record in (result[timestamp][state] for state in states)
        ]
        for timestamp in result
    ]
    return np.transpose(np.array(per_timestamp), (1, 0, 2)), graph, states

0 comments on commit 2d70e5b

Please sign in to comment.