diff --git a/examples/ihme_api/cat_sex_split_example.py b/examples/ihme_api/cat_sex_split_example.py new file mode 100644 index 0000000..3629ae3 --- /dev/null +++ b/examples/ihme_api/cat_sex_split_example.py @@ -0,0 +1,174 @@ +import pandas as pd +import numpy as np + +# Import CatSplitter and configurations from your module +from pydisagg.ihme.splitter import ( + CatSplitter, + CatDataConfig, + CatPatternConfig, + CatPopulationConfig, +) + +# Set a random seed for reproducibility +np.random.seed(42) + +# ------------------------------- +# 1. Create and Update data_df +# ------------------------------- + +# Existing data_df DataFrame +data_df = pd.DataFrame( + { + "seq": [303284043, 303284062, 303284063, 303284064, 303284065], + "location_id": [78, 130, 120, 30, 141], + "mean": [0.5] * 5, + "standard_error": [0.1] * 5, + "year_id": [2015, 2019, 2018, 2017, 2016], + } +) + +# Adding the 'sex' column with a list [1, 2] for each row +data_df["sex"] = [[1, 2]] * len(data_df) + +# Sort data_df for clarity +data_df_sorted = data_df.sort_values(by=["location_id"]).reset_index(drop=True) + +# Display the sorted data_df +print("data_df:") +print(data_df_sorted) + +# ------------------------------- +# 2. Create and Update pattern_df_final +# ------------------------------- + +pattern_df = pd.DataFrame( + { + "location_id": [78, 130, 120, 30, 141], + "mean": [0.5] * 5, + "standard_error": [0.1] * 5, + "year_id": [2015, 2019, 2018, 2017, 2016], + } +) + +# Create DataFrame for sex=1 +pattern_df_sex1 = pattern_df.copy() +pattern_df_sex1["sex"] = 1 # Assign sex=1 +pattern_df_sex1["mean"] += np.random.normal(0, 0.01, size=len(pattern_df_sex1)) +pattern_df_sex1["standard_error"] += np.random.normal( + 0, 0.001, size=len(pattern_df_sex1) +) +pattern_df_sex1["mean"] = pattern_df_sex1["mean"].round(6) +pattern_df_sex1["standard_error"] = pattern_df_sex1["standard_error"].round(6) + +# Create DataFrame for sex=2 +pattern_df_sex2 = pattern_df.copy() +pattern_df_sex2["sex"] = 2 # Assign sex=2 +pattern_df_sex2["mean"] += np.random.normal(0, 0.01, size=len(pattern_df_sex2)) +pattern_df_sex2["standard_error"] += np.random.normal( + 0, 0.001, size=len(pattern_df_sex2) +) +pattern_df_sex2["mean"] = pattern_df_sex2["mean"].round(6) +pattern_df_sex2["standard_error"] = pattern_df_sex2["standard_error"].round(6) + +pattern_df_final = pd.concat( + [pattern_df_sex1, pattern_df_sex2], ignore_index=True +) + +# Sort pattern_df_final for clarity +pattern_df_final_sorted = pattern_df_final.sort_values( + by=["location_id", "sex"] +).reset_index(drop=True) + +print("\npattern_df_final:") +print(pattern_df_final_sorted) + +# ------------------------------- +# 3. Create and Update population_df +# ------------------------------- + +population_df = pd.DataFrame( + { + "location_id": [30, 30, 78, 78, 120, 120, 130, 130, 141, 141], + "year_id": [2017] * 2 + + [2015] * 2 + + [2018] * 2 + + [2019] * 2 + + [2016] * 2, + "sex": [1, 2] * 5, # Sexes 1 and 2 + "population": [ + 39789, + 40120, + 10234, + 10230, + 30245, + 29870, + 19876, + 19980, + 50234, + 49850, + ], + } +) + +# Sort population_df for clarity +population_df_sorted = population_df.sort_values( + by=["location_id", "sex"] +).reset_index(drop=True) + +# Display the sorted population_df +print("\npopulation_df:") +print(population_df_sorted) + +# ------------------------------- +# 4. 
Configure and Run CatSplitter +# ------------------------------- + +# Data configuration +data_config = CatDataConfig( + index=[ + "seq", + "location_id", + "year_id", + "sex", + ], # Include 'sex' in the index + cat_group="sex", + val="mean", + val_sd="standard_error", +) + +# Pattern configuration +pattern_config = CatPatternConfig( + by=["location_id", "year_id"], + cat="sex", + val="mean", + val_sd="standard_error", +) + +# Population configuration +population_config = CatPopulationConfig( + index=["location_id", "year_id", "sex"], # Include 'sex' in the index + val="population", +) + +# Initialize the CatSplitter +splitter = CatSplitter( + data=data_config, pattern=pattern_config, population=population_config +) + +# Perform the split +try: + final_split_df = splitter.split( + data=data_df, + pattern=pattern_df_final, + population=population_df, + model="rate", + output_type="rate", + ) + + # Sort the final DataFrame by 'seq' and then by 'sex' + final_split_df.sort_values(by=["seq", "sex"], inplace=True) + + print("\nFinal Split DataFrame:") + print(final_split_df) +except Exception as e: + print(f"Error during splitting: {e}") diff --git a/examples/ihme_api/cat_split.ipynb b/examples/ihme_api/cat_split.ipynb new file mode 100644 index 0000000..e328e09 --- /dev/null +++ b/examples/ihme_api/cat_split.ipynb @@ -0,0 +1,401 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Pre-split DataFrame:\n", + " study_id year_id location_id mean std_err\n", + "0 8270 2010 [1234, 1235, 1236] 0.2 0.01\n", + "1 1860 2010 [2345, 2346, 2347] 0.3 0.02\n", + "2 6390 2010 [3456] 0.4 0.03\n", + "\n", + "Pattern DataFrame:\n", + " location_id year_id mean std_err\n", + "0 1234 2010 0.392798 0.048796\n", + "1 1235 2010 0.339463 0.043298\n", + "2 1236 2010 0.162407 0.018494\n", + "3 2345 2010 0.162398 0.017273\n", + "4 2346 2010 0.123233 0.017336\n", + "5 2347 2010 0.446470 0.022170\n", + "6 3456 2010 0.340446 0.030990\n", + "7 4567 2010 0.383229 0.027278\n", + "8 5678 2010 0.108234 0.021649\n", + "\n", + "Population DataFrame:\n", + " location_id year_id population\n", + "0 1234 2010 166730\n", + "1 1235 2010 880910\n", + "2 1236 2010 394681\n", + "3 2345 2010 159503\n", + "4 2346 2010 664811\n", + "5 2347 2010 537035\n", + "6 3456 2010 658143\n", + "7 4567 2010 462366\n", + "8 5678 2010 75725\n", + "\n", + "Final Split DataFrame:\n", + " mean study_id std_err location_id year_id cat_pat_mean \\\n", + "3 0.3 1860 0.02 2345 2010 0.162398 \n", + "4 0.3 1860 0.02 2346 2010 0.123233 \n", + "5 0.3 1860 0.02 2347 2010 0.446470 \n", + "6 0.4 6390 0.03 3456 2010 0.340446 \n", + "0 0.2 8270 0.01 1234 2010 0.392798 \n", + "1 0.2 8270 0.01 1235 2010 0.339463 \n", + "2 0.2 8270 0.01 1236 2010 0.162407 \n", + "\n", + " cat_pat_std_err population split_result split_result_se split_flag \\\n", + "3 0.017273 159503.0 0.190806 0.024440 1 \n", + "4 0.017336 664811.0 0.144790 0.019012 1 \n", + "5 0.022170 537035.0 0.524570 0.040101 1 \n", + "6 0.030990 658143.0 0.400000 0.030000 0 \n", + "0 0.048796 166730.0 0.264351 0.039018 1 \n", + "1 0.043298 880910.0 0.228457 0.015557 1 \n", + "2 0.018494 394681.0 0.109300 0.015518 1 \n", + "\n", + " orig_group \n", + "3 [2345, 2346, 2347] \n", + "4 [2345, 2346, 2347] \n", + "5 [2345, 2346, 2347] \n", + "6 [3456] \n", + "0 [1234, 1235, 1236] \n", + "1 [1234, 1235, 1236] \n", + "2 [1234, 1235, 1236] \n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "import numpy as 
np\n", + "\n", + "# Assuming the CatSplitter and configuration classes have been imported correctly\n", + "from pydisagg.ihme.splitter import (\n", + " CatSplitter,\n", + " CatDataConfig,\n", + " CatPatternConfig,\n", + " CatPopulationConfig,\n", + ")\n", + "\n", + "# Set a random seed for reproducibility\n", + "np.random.seed(42)\n", + "\n", + "# -------------------------------\n", + "# Example DataFrames\n", + "# -------------------------------\n", + "\n", + "# Pre-split DataFrame with 3 rows\n", + "pre_split = pd.DataFrame(\n", + " {\n", + " \"study_id\": np.random.randint(1000, 9999, size=3), # Unique study IDs\n", + " \"year_id\": [2010, 2010, 2010],\n", + " \"location_id\": [\n", + " [1234, 1235, 1236], # List of location_ids for row 1\n", + " [2345, 2346, 2347], # List of location_ids for row 2\n", + " [3456], # Single location_id for row 3 (no need to split)\n", + " ],\n", + " \"mean\": [0.2, 0.3, 0.4],\n", + " \"std_err\": [0.01, 0.02, 0.03],\n", + " }\n", + ")\n", + "\n", + "# Create a list of all location_ids mentioned\n", + "all_location_ids = [\n", + " 1234,\n", + " 1235,\n", + " 1236,\n", + " 2345,\n", + " 2346,\n", + " 2347,\n", + " 3456,\n", + " 4567, # Additional location_ids\n", + " 5678,\n", + "]\n", + "\n", + "# Pattern DataFrame for all location_ids\n", + "data_pattern = pd.DataFrame(\n", + " {\n", + " \"location_id\": all_location_ids,\n", + " \"year_id\": [2010] * len(all_location_ids),\n", + " \"mean\": np.random.uniform(0.1, 0.5, len(all_location_ids)),\n", + " \"std_err\": np.random.uniform(0.01, 0.05, len(all_location_ids)),\n", + " }\n", + ")\n", + "\n", + "# Population DataFrame for all location_ids\n", + "data_pop = pd.DataFrame(\n", + " {\n", + " \"location_id\": all_location_ids,\n", + " \"year_id\": [2010] * len(all_location_ids),\n", + " \"population\": np.random.randint(10000, 1000000, len(all_location_ids)),\n", + " }\n", + ")\n", + "\n", + "# Print the DataFrames\n", + "print(\"Pre-split DataFrame:\")\n", + "print(pre_split)\n", + "print(\"\\nPattern DataFrame:\")\n", + "print(data_pattern)\n", + "print(\"\\nPopulation DataFrame:\")\n", + "print(data_pop)\n", + "\n", + "# -------------------------------\n", + "# Configurations\n", + "# -------------------------------\n", + "\n", + "data_config = CatDataConfig(\n", + " index=[\"study_id\", \"year_id\"], # Include study_id in the index\n", + " target=\"location_id\", # Column containing list of targets\n", + " val=\"mean\",\n", + " val_sd=\"std_err\",\n", + ")\n", + "\n", + "pattern_config = CatPatternConfig(\n", + " index=[\"year_id\"],\n", + " target=\"location_id\",\n", + " val=\"mean\",\n", + " val_sd=\"std_err\",\n", + ")\n", + "\n", + "population_config = CatPopulationConfig(\n", + " index=[\"year_id\"],\n", + " target=\"location_id\",\n", + " val=\"population\",\n", + ")\n", + "\n", + "# Initialize the CatSplitter\n", + "splitter = CatSplitter(\n", + " data=data_config, pattern=pattern_config, population=population_config\n", + ")\n", + "\n", + "# Perform the split\n", + "try:\n", + " final_split_df = splitter.split(\n", + " data=pre_split,\n", + " pattern=data_pattern,\n", + " population=data_pop,\n", + " model=\"rate\",\n", + " output_type=\"rate\",\n", + " )\n", + " final_split_df.sort_values(by=[\"study_id\", \"location_id\"], inplace=True)\n", + " print(\"\\nFinal Split DataFrame:\")\n", + " print(final_split_df)\n", + "except ValueError as e:\n", + " print(f\"Error: {e}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + 
"text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
meanstudy_idstd_errlocation_idyear_idcat_pat_meancat_pat_std_errpopulationsplit_resultsplit_result_sesplit_flagorig_group
30.318600.02234520100.1623980.017273159503.00.1908060.0244401[2345, 2346, 2347]
40.318600.02234620100.1232330.017336664811.00.1447900.0190121[2345, 2346, 2347]
50.318600.02234720100.4464700.022170537035.00.5245700.0401011[2345, 2346, 2347]
60.463900.03345620100.3404460.030990658143.00.4000000.0300000[3456]
00.282700.01123420100.3927980.048796166730.00.2643510.0390181[1234, 1235, 1236]
10.282700.01123520100.3394630.043298880910.00.2284570.0155571[1234, 1235, 1236]
20.282700.01123620100.1624070.018494394681.00.1093000.0155181[1234, 1235, 1236]
\n", + "
" + ], + "text/plain": [ + " mean study_id std_err location_id year_id cat_pat_mean \\\n", + "3 0.3 1860 0.02 2345 2010 0.162398 \n", + "4 0.3 1860 0.02 2346 2010 0.123233 \n", + "5 0.3 1860 0.02 2347 2010 0.446470 \n", + "6 0.4 6390 0.03 3456 2010 0.340446 \n", + "0 0.2 8270 0.01 1234 2010 0.392798 \n", + "1 0.2 8270 0.01 1235 2010 0.339463 \n", + "2 0.2 8270 0.01 1236 2010 0.162407 \n", + "\n", + " cat_pat_std_err population split_result split_result_se split_flag \\\n", + "3 0.017273 159503.0 0.190806 0.024440 1 \n", + "4 0.017336 664811.0 0.144790 0.019012 1 \n", + "5 0.022170 537035.0 0.524570 0.040101 1 \n", + "6 0.030990 658143.0 0.400000 0.030000 0 \n", + "0 0.048796 166730.0 0.264351 0.039018 1 \n", + "1 0.043298 880910.0 0.228457 0.015557 1 \n", + "2 0.018494 394681.0 0.109300 0.015518 1 \n", + "\n", + " orig_group \n", + "3 [2345, 2346, 2347] \n", + "4 [2345, 2346, 2347] \n", + "5 [2345, 2346, 2347] \n", + "6 [3456] \n", + "0 [1234, 1235, 1236] \n", + "1 [1234, 1235, 1236] \n", + "2 [1234, 1235, 1236] " + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "final_split_df" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pyDis-mac", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/ihme_api/cat_split_example.py b/examples/ihme_api/cat_split_example.py new file mode 100644 index 0000000..94dbb6e --- /dev/null +++ b/examples/ihme_api/cat_split_example.py @@ -0,0 +1,126 @@ +# cat_split_example.py + +import numpy as np +import pandas as pd +from pandas import DataFrame + +# Import CatSplitter and related classes +from pydisagg.ihme.splitter import ( + CatSplitter, + CatDataConfig, + CatPatternConfig, + CatPopulationConfig, +) + +# ------------------------------- +# Example DataFrames +# ------------------------------- + +# Set a random seed for reproducibility +np.random.seed(42) + +# Pre-split DataFrame with 3 rows +pre_split = pd.DataFrame( + { + "study_id": np.random.randint(1000, 9999, size=3), # Unique study IDs + "year_id": [2010, 2010, 2010], + "location_id": [ + [1234, 1235, 1236], # List of location_ids for row 1 + [2345, 2346, 2347], # List of location_ids for row 2 + [3456], # Single location_id for row 3 (no need to split) + ], + "mean": [0.2, 0.3, 0.4], + "std_err": [0.01, 0.02, 0.03], + } +) + +# Create a list of all location_ids mentioned +all_location_ids = [ + 1234, + 1235, + 1236, + 2345, + 2346, + 2347, + 3456, + 4567, + 5678, # Additional location_ids +] + +# Pattern DataFrame for all location_ids +data_pattern = pd.DataFrame( + { + "year_id": [2010] * len(all_location_ids), + "location_id": all_location_ids, + "mean": np.random.uniform(0.1, 0.5, len(all_location_ids)), + "std_err": np.random.uniform(0.01, 0.05, len(all_location_ids)), + } +) + +# Population DataFrame for all location_ids +data_pop = pd.DataFrame( + { + "year_id": [2010] * len(all_location_ids), + "location_id": all_location_ids, + "population": np.random.randint(10000, 1000000, len(all_location_ids)), + } +) + +# Print the DataFrames +print("Pre-split DataFrame:") +print(pre_split) +print("\nPattern DataFrame:") +print(data_pattern) +print("\nPopulation DataFrame:") +print(data_pop) + +# 
------------------------------- +# Configurations +# ------------------------------- + +# Adjusted configurations to match the modified CatSplitter +data_config = CatDataConfig( + index=[ + "study_id", + "year_id", + "location_id", + ], # Include 'location_id' in the index + cat_group="location_id", + val="mean", + val_sd="std_err", +) + +pattern_config = CatPatternConfig( + by=["year_id"], + cat="location_id", + val="mean", + val_sd="std_err", +) + +population_config = CatPopulationConfig( + index=["year_id", "location_id"], + val="population", +) + +# Initialize the CatSplitter with the updated configurations +splitter = CatSplitter( + data=data_config, + pattern=pattern_config, + population=population_config, +) + +# Perform the split +try: + final_split_df = splitter.split( + data=pre_split, + pattern=data_pattern, + population=data_pop, + model="rate", + output_type="rate", + ) + # Sort the final DataFrame for better readability + final_split_df.sort_values(by=["study_id", "location_id"], inplace=True) + print("\nFinal Split DataFrame:") + print(final_split_df) +except Exception as e: + print(f"Error during splitting: {e}") diff --git a/pyproject.toml b/pyproject.toml index bc768b0..90dec5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ [project] name = "pydisagg" -version = "0.5.2" +version = "0.6.0" description = "" readme = "README.md" license = { text = "BSD 2-Clause License" } diff --git a/src/pydisagg/ihme/splitter/__init__.py b/src/pydisagg/ihme/splitter/__init__.py index 2687c09..c954cdf 100644 --- a/src/pydisagg/ihme/splitter/__init__.py +++ b/src/pydisagg/ihme/splitter/__init__.py @@ -10,6 +10,12 @@ SexPatternConfig, SexPopulationConfig, ) +from .cat_splitter import ( + CatSplitter, + CatDataConfig, + CatPatternConfig, + CatPopulationConfig, +) __all__ = [ "AgeSplitter", @@ -20,4 +26,8 @@ "SexDataConfig", "SexPatternConfig", "SexPopulationConfig", + "CatSplitter", + "CatDataConfig", + "CatPatternConfig", + "CatPopulationConfig", ] diff --git a/src/pydisagg/ihme/splitter/age_splitter.py b/src/pydisagg/ihme/splitter/age_splitter.py index bd389dc..e0f255e 100644 --- a/src/pydisagg/ihme/splitter/age_splitter.py +++ b/src/pydisagg/ihme/splitter/age_splitter.py @@ -197,7 +197,7 @@ def parse_pattern( def _merge_with_pattern( self, data: DataFrame, pattern: DataFrame ) -> DataFrame: - # Ensure the necessary columns are present before merging + # TODO change these asserts to validate_columns assert ( self.data.age_lwr in data.columns ), f"Column '{self.data.age_lwr}' not found in data" @@ -236,11 +236,10 @@ def parse_population( validate_index(population, self.population.index, name) validate_nonan(population, name) - pop_copy = population.copy() rename_map = self.population.apply_prefix() - pop_copy.rename(columns=rename_map, inplace=True) + population.rename(columns=rename_map, inplace=True) - data_with_population = self._merge_with_population(data, pop_copy) + data_with_population = self._merge_with_population(data, population) validate_noindexdiff( data, diff --git a/src/pydisagg/ihme/splitter/cat_splitter.py b/src/pydisagg/ihme/splitter/cat_splitter.py new file mode 100644 index 0000000..2ee7545 --- /dev/null +++ b/src/pydisagg/ihme/splitter/cat_splitter.py @@ -0,0 +1,330 @@ +# cat_splitter.py + +from typing import Any, List, Literal +import numpy as np +import pandas as pd +from pandas import DataFrame +from pydantic import BaseModel + +from pydisagg.disaggregate import split_datapoint +from pydisagg.models import RateMultiplicativeModel, 
LogOddsModel
+from pydisagg.ihme.schema import Schema
+from pydisagg.ihme.validator import (
+    validate_columns,
+    validate_index,
+    validate_noindexdiff,
+    validate_nonan,
+    validate_positive,
+    validate_set_uniqueness,
+)
+
+
+class CatDataConfig(Schema):
+    """
+    Configuration schema for the categorical data DataFrame.
+    """
+
+    index: List[str]
+    cat_group: str
+    val: str
+    val_sd: str
+
+    @property
+    def columns(self) -> List[str]:
+        return self.index + [self.cat_group, self.val, self.val_sd]
+
+    @property
+    def val_fields(self) -> List[str]:
+        return [self.val, self.val_sd]
+
+
+class CatPatternConfig(Schema):
+    """
+    Configuration schema for the pattern DataFrame.
+    """
+
+    by: List[str]
+    cat: str
+    draws: List[str] = []
+    val: str = "mean"
+    val_sd: str = "std_err"
+    prefix: str = "cat_pat_"
+
+    @property
+    def index(self) -> List[str]:
+        return self.by + [self.cat]
+
+    @property
+    def columns(self) -> List[str]:
+        return self.index + self.val_fields + self.draws
+
+    @property
+    def val_fields(self) -> List[str]:
+        return [self.val, self.val_sd]
+
+
+class CatPopulationConfig(Schema):
+    """
+    Configuration schema for the population DataFrame.
+    """
+
+    index: List[str]
+    val: str
+    prefix: str = "cat_pop_"
+
+    @property
+    def columns(self) -> List[str]:
+        return self.index + [self.val]
+
+    @property
+    def val_fields(self) -> List[str]:
+        return [self.val]
+
+
+class CatSplitter(BaseModel):
+    """
+    Class for splitting categorical data based on pattern and population data.
+    """
+
+    data: CatDataConfig
+    pattern: CatPatternConfig
+    population: CatPopulationConfig
+
+    def model_post_init(self, __context: Any) -> None:
+        """
+        Perform extra validation after model initialization.
+        """
+        if not set(self.pattern.index).issubset(self.data.index):
+            raise ValueError(
+                "The pattern's index columns must be a subset of the data index."
+            )
+        if not set(self.population.index).issubset(
+            self.data.index + self.pattern.index
+        ):
+            raise ValueError(
+                "The population's index columns must be a subset of the data and pattern indices."
+            )
+        if self.pattern.cat not in self.population.index:
+            raise ValueError(
+                "The pattern's categorical column must be part of the population index."
+            )
+
+    def parse_data(self, data: DataFrame, positive_strict: bool) -> DataFrame:
+        """
+        Parse and validate the input data DataFrame.
+        """
+        name = "While parsing data"
+
+        # Validate required columns
+        validate_columns(data, self.data.columns, name)
+
+        # Work on a copy so the caller's DataFrame is not mutated
+        data = data.copy()
+
+        # Ensure that the 'cat_group' column contains lists
+        data[self.data.cat_group] = data[self.data.cat_group].apply(
+            lambda x: x if isinstance(x, list) else [x]
+        )
+
+        # Validate that every list in 'cat_group' contains unique elements
+        validate_set_uniqueness(data, self.data.cat_group, name)
+
+        # Explode the 'cat_group' column and rename it to match the pattern's 'cat'
+        data = data.explode(self.data.cat_group).rename(
+            columns={self.data.cat_group: self.pattern.cat}
+        )
+
+        # Validate index after exploding
+        validate_index(data, self.data.index, name)
+        validate_nonan(data, name)
+        validate_positive(data, [self.data.val_sd], name, strict=positive_strict)
+
+        return data
+
+    def _merge_with_pattern(
+        self,
+        data: DataFrame,
+        pattern: DataFrame,
+    ) -> DataFrame:
+        """
+        Merge data with the pattern DataFrame.
+        """
+        data_with_pattern = data.merge(pattern, on=self.pattern.index, how="left")
+
+        validate_nonan(
+            data_with_pattern[
+                [f"{self.pattern.prefix}{col}" for col in self.pattern.val_fields]
+            ],
+            "After merging with pattern, there were NaN values created. 
This indicates that your pattern does not cover all the data.", + ) + + return data_with_pattern + + def parse_pattern( + self, data: DataFrame, pattern: DataFrame, model: str + ) -> DataFrame: + """ + Parse and merge the pattern DataFrame with data. + """ + name = "While parsing pattern" + + try: + val_cols = self.pattern.val_fields + if not all(col in pattern.columns for col in val_cols): + if not self.pattern.draws: + raise ValueError( + f"{name}: Must provide draws for pattern if pattern.val and " + "pattern.val_sd are not available." + ) + validate_columns(pattern, self.pattern.draws, name) + pattern[self.pattern.val] = pattern[self.pattern.draws].mean(axis=1) + pattern[self.pattern.val_sd] = pattern[self.pattern.draws].std(axis=1) + + validate_columns(pattern, self.pattern.columns, name) + except KeyError as e: + raise KeyError(f"{name}: Missing columns in the pattern. Details:\n{e}") + + pattern_copy = pattern.copy() + pattern_copy = pattern_copy[self.pattern.index + self.pattern.val_fields] + rename_map = { + col: f"{self.pattern.prefix}{col}" for col in self.pattern.val_fields + } + pattern_copy.rename(columns=rename_map, inplace=True) + + # Merge with pattern + data_with_pattern = self._merge_with_pattern(data, pattern_copy) + + # Validate index differences after merging + validate_noindexdiff( + data, + data_with_pattern, + self.data.index, + name, + ) + + return data_with_pattern + + def parse_population(self, data: DataFrame, population: DataFrame) -> DataFrame: + name = "Parsing Population" + validate_columns(population, self.population.columns, name) + + population = population[self.population.columns].copy() + + validate_index(population, self.population.index, name) + validate_nonan(population, name) + + rename_map = { + self.population.val: f"{self.population.prefix}{self.population.val}" + } + population.rename(columns=rename_map, inplace=True) + + data_with_population = self._merge_with_population(data, population) + + # Ensure the prefixed population column exists + pop_col = f"{self.population.prefix}{self.population.val}" + if pop_col not in data_with_population.columns: + raise KeyError(f"Expected column '{pop_col}' not found in merged data.") + + validate_nonan( + data_with_population[[pop_col]], + "After merging with population, there were NaN values created. This indicates that your population data does not cover all the data.", + ) + return data_with_population + + def _merge_with_population( + self, data: DataFrame, population: DataFrame + ) -> DataFrame: + """ + Merge data with population DataFrame. + """ + data_with_population = data.merge( + population, on=self.population.index, how="left" + ) + + return data_with_population + + def _process_group( + self, group: DataFrame, model: str, output_type: str + ) -> DataFrame: + """ + Process a group of data for splitting. 
+ """ + observed_total = group[self.data.val].iloc[0] + observed_total_se = group[self.data.val_sd].iloc[0] + + if len(group) == 1: + # No need to split, assign the observed values + group["split_result"] = observed_total + group["split_result_se"] = observed_total_se + group["split_flag"] = 0 # Not split + else: + # Need to split among multiple targets + bucket_populations = group[ + f"{self.population.prefix}{self.population.val}" + ].values + rate_pattern = group[f"{self.pattern.prefix}{self.pattern.val}"].values + pattern_sd = group[f"{self.pattern.prefix}{self.pattern.val_sd}"].values + pattern_covariance = np.diag(pattern_sd**2) + + if model == "rate": + splitting_model = RateMultiplicativeModel() + elif model == "logodds": + splitting_model = LogOddsModel() + + # Determine whether to normalize by population for the output type + pop_normalize = output_type == "rate" + + # Perform splitting + split_result, split_se = split_datapoint( + observed_total=observed_total, + bucket_populations=bucket_populations, + rate_pattern=rate_pattern, + model=splitting_model, + output_type=output_type, + normalize_pop_for_average_type_obs=pop_normalize, + observed_total_se=observed_total_se, + pattern_covariance=pattern_covariance, + ) + + # Assign results back to the group + group["split_result"] = split_result + group["split_result_se"] = split_se + group["split_flag"] = 1 # Split + + return group + + def split( + self, + data: DataFrame, + pattern: DataFrame, + population: DataFrame, + model: Literal["rate", "logodds"] = "rate", + output_type: Literal["rate", "count"] = "rate", + ) -> DataFrame: + """ + Split the input data based on a specified pattern and population model. + """ + # Validate model and output_type + if model not in ["rate", "logodds"]: + raise ValueError(f"Invalid model: {model}") + if output_type not in ["rate", "count"]: + raise ValueError(f"Invalid output_type: {output_type}") + + if self.population.prefix_status == "prefixed": + self.population.remove_prefix() + if self.pattern.prefix_status == "prefixed": + self.pattern.remove_prefix() + + # Parsing input data, pattern, and population + data = self.parse_data(data, positive_strict=True) + data = self.parse_pattern(data, pattern, model) + data = self.parse_population(data, population) + + # Determine grouping columns + group_cols = self.data.index[:-1] # Exclude 'location_id' from grouping + + # Process groups using regular groupby + final_split_df = ( + data.groupby(group_cols, group_keys=False) + .apply(lambda group: self._process_group(group, model, output_type)) + .reset_index(drop=True) + ) + + return final_split_df diff --git a/src/pydisagg/ihme/splitter/sex_splitter.py b/src/pydisagg/ihme/splitter/sex_splitter.py index 4ebd359..b5008ee 100644 --- a/src/pydisagg/ihme/splitter/sex_splitter.py +++ b/src/pydisagg/ihme/splitter/sex_splitter.py @@ -89,8 +89,12 @@ def model_post_init(self, __context: Any) -> None: "population.index must be a subset of data.index + pattern.index" ) - def _merge_with_pattern(self, data: DataFrame, pattern: DataFrame) -> DataFrame: - data_with_pattern = data.merge(pattern, on=self.pattern.by, how="left").dropna() + def _merge_with_pattern( + self, data: DataFrame, pattern: DataFrame + ) -> DataFrame: + data_with_pattern = data.merge( + pattern, on=self.pattern.by, how="left" + ).dropna() return data_with_pattern def get_population_by_sex(self, population, sex_value): @@ -105,7 +109,9 @@ def parse_data(self, data: DataFrame) -> DataFrame: try: validate_columns(data, self.data.columns, name) except 
KeyError as e: - raise KeyError(f"{name}: Missing columns in the input data. Details:\n{e}") + raise KeyError( + f"{name}: Missing columns in the input data. Details:\n{e}" + ) if self.population.sex not in data.columns: raise KeyError( @@ -147,12 +153,18 @@ def parse_pattern( "pattern.val_sd are not available." ) validate_columns(pattern, self.pattern.draws, name) - pattern[self.pattern.val] = pattern[self.pattern.draws].mean(axis=1) - pattern[self.pattern.val_sd] = pattern[self.pattern.draws].std(axis=1) + pattern[self.pattern.val] = pattern[self.pattern.draws].mean( + axis=1 + ) + pattern[self.pattern.val_sd] = pattern[self.pattern.draws].std( + axis=1 + ) validate_columns(pattern, self.pattern.columns, name) except KeyError as e: - raise KeyError(f"{name}: Missing columns in the pattern. Details:\n{e}") + raise KeyError( + f"{name}: Missing columns in the pattern. Details:\n{e}" + ) pattern = pattern[self.pattern.columns].copy() @@ -166,7 +178,9 @@ def parse_pattern( try: validate_nonan(pattern, name) except ValueError as e: - raise ValueError(f"{name}: NaN values found in the pattern. Details:\n{e}") + raise ValueError( + f"{name}: NaN values found in the pattern. Details:\n{e}" + ) if model == "rate": try: @@ -196,7 +210,9 @@ def parse_pattern( return data_with_pattern - def parse_population(self, data: DataFrame, population: DataFrame) -> DataFrame: + def parse_population( + self, data: DataFrame, population: DataFrame + ) -> DataFrame: name = "While parsing population" # Step 1: Validate population columns @@ -208,13 +224,19 @@ def parse_population(self, data: DataFrame, population: DataFrame) -> DataFrame: ) # Step 2: Get male and female populations and rename columns - male_population = self.get_population_by_sex(population, self.population.sex_m) + male_population = self.get_population_by_sex( + population, self.population.sex_m + ) female_population = self.get_population_by_sex( population, self.population.sex_f ) - male_population.rename(columns={self.population.val: "m_pop"}, inplace=True) - female_population.rename(columns={self.population.val: "f_pop"}, inplace=True) + male_population.rename( + columns={self.population.val: "m_pop"}, inplace=True + ) + female_population.rename( + columns={self.population.val: "f_pop"}, inplace=True + ) # Step 3: Merge population data with main data data_with_population = self._merge_with_population( @@ -242,7 +264,9 @@ def parse_population(self, data: DataFrame, population: DataFrame) -> DataFrame: # Step 6: Validate index differences try: - validate_noindexdiff(data, data_with_population, self.data.index, name) + validate_noindexdiff( + data, data_with_population, self.data.index, name + ) except ValueError as e: raise ValueError( f"{name}: Index differences found between data and population. 
Details:\n{e}" @@ -250,8 +274,12 @@ def parse_population(self, data: DataFrame, population: DataFrame) -> DataFrame: # Ensure the columns are in the correct numeric type (e.g., float64) # Convert "m_pop" and "f_pop" columns to standard numeric types if necessary - data_with_population["m_pop"] = data_with_population["m_pop"].astype("float64") - data_with_population["f_pop"] = data_with_population["f_pop"].astype("float64") + data_with_population["m_pop"] = data_with_population["m_pop"].astype( + "float64" + ) + data_with_population["f_pop"] = data_with_population["f_pop"].astype( + "float64" + ) return data_with_population @@ -268,9 +296,9 @@ def _merge_with_population( # Ensure the merged population columns are standard numeric types if pop_col in data_with_population.columns: - data_with_population[pop_col] = data_with_population[pop_col].astype( - "float64" - ) + data_with_population[pop_col] = data_with_population[ + pop_col + ].astype("float64") return data_with_population @@ -376,12 +404,16 @@ def split( lambda row: split_datapoint( observed_total=row[self.data.val], bucket_populations=np.array([row["m_pop"], row["f_pop"]]), - rate_pattern=input_patterns[split_data.index.get_loc(row.name)], + rate_pattern=input_patterns[ + split_data.index.get_loc(row.name) + ], model=splitting_model, output_type=output_type, normalize_pop_for_average_type_obs=pop_normalize, observed_total_se=row[self.data.val_sd], - pattern_covariance=np.diag([0, row[self.pattern.val_sd] ** 2]), + pattern_covariance=np.diag( + [0, row[self.pattern.val_sd] ** 2] + ), ), axis=1, ) @@ -412,7 +444,11 @@ def split( # Reindex columns final_split_df = final_split_df.reindex( columns=self.data.index - + [col for col in final_split_df.columns if col not in self.data.index] + + [ + col + for col in final_split_df.columns + if col not in self.data.index + ] ) # Clean up any prefixes added earlier diff --git a/src/pydisagg/ihme/validator.py b/src/pydisagg/ihme/validator.py index 1c86379..550b253 100644 --- a/src/pydisagg/ihme/validator.py +++ b/src/pydisagg/ihme/validator.py @@ -4,6 +4,23 @@ def validate_columns(df: DataFrame, columns: list[str], name: str) -> None: + """ + Validates that all specified columns are present in the DataFrame. + + Parameters + ---------- + df : pandas.DataFrame + The DataFrame to validate. + columns : list of str + A list of expected column names that should be present in the DataFrame. + name : str + A name for the DataFrame, used in error messages. + + Raises + ------ + KeyError + If any of the specified columns are missing from the DataFrame. + """ missing = [col for col in columns if col not in df.columns] if missing: error_message = ( @@ -18,11 +35,30 @@ def validate_columns(df: DataFrame, columns: list[str], name: str) -> None: def validate_index(df: DataFrame, index: list[str], name: str) -> None: + """ + Validates that the DataFrame does not contain duplicate indices based on specified columns. + + Parameters + ---------- + df : pandas.DataFrame + The DataFrame to validate. + index : list of str + A list of column names to be used as the index for validation. + name : str + A name for the DataFrame, used in error messages. + + Raises + ------ + ValueError + If duplicate indices are found in the DataFrame based on the specified columns. 
+ """ duplicated_index = pd.MultiIndex.from_frame( df[df[index].duplicated()][index] ).to_list() if duplicated_index: - error_message = f"{name} has duplicated index with {len(duplicated_index)} indices \n" + error_message = ( + f"{name} has duplicated index with {len(duplicated_index)} indices \n" + ) error_message += f"Index columns: ({', '.join(index)})\n" if len(duplicated_index) > 5: error_message += "First 5: \n" @@ -32,11 +68,24 @@ def validate_index(df: DataFrame, index: list[str], name: str) -> None: def validate_nonan(df: DataFrame, name: str) -> None: + """ + Validates that the DataFrame does not contain any NaN values. + + Parameters + ---------- + df : pandas.DataFrame + The DataFrame to validate. + name : str + A name for the DataFrame, used in error messages. + + Raises + ------ + ValueError + If any NaN values are found in the DataFrame. + """ nan_columns = df.columns[df.isna().any(axis=0)].to_list() if nan_columns: - error_message = ( - f"{name} has NaN values in {len(nan_columns)} columns. \n" - ) + error_message = f"{name} has NaN values in {len(nan_columns)} columns. \n" error_message += f"Columns with NaN values: {', '.join(nan_columns)}\n" if len(nan_columns) > 5: error_message += "First 5 columns with NaN values: \n" @@ -48,7 +97,27 @@ def validate_nonan(df: DataFrame, name: str) -> None: def validate_positive( df: DataFrame, columns: list[str], name: str, strict: bool = False ) -> None: - """Validates that observation values in cols are non-negative or strictly positive""" + """ + Validates that specified columns contain non-negative or strictly positive values. + + Parameters + ---------- + df : pandas.DataFrame + The DataFrame to validate. + columns : list of str + A list of column names to check for positive values. + name : str + A name for the DataFrame, used in error messages. + strict : bool, optional + If True, checks that values are strictly greater than zero. + If False, checks that values are greater than or equal to zero. + Default is False. + + Raises + ------ + ValueError + If any of the specified columns contain invalid (negative or zero) values. + """ op = "<=" if strict else "<" negative = [col for col in columns if df.eval(f"{col} {op} 0").any()] if negative: @@ -59,11 +128,32 @@ def validate_positive( def validate_interval( df: DataFrame, lwr: str, upr: str, index: list[str], name: str ) -> None: + """ + Validates that lower interval bounds are strictly less than upper bounds. + + Parameters + ---------- + df : pandas.DataFrame + The DataFrame containing interval data to validate. + lwr : str + The name of the column representing the lower bound of the interval. + upr : str + The name of the column representing the upper bound of the interval. + index : list of str + A list of column names to be used as the index for identifying intervals. + name : str + A name for the DataFrame, used in error messages. + + Raises + ------ + ValueError + If any lower bound is not strictly less than its corresponding upper bound. + """ invalid_index = pd.MultiIndex.from_frame( df.query(f"{lwr} >= {upr}")[index] ).to_list() if invalid_index: - error_message = f"{name} has invalid interval with {len(invalid_index)} indices. \nLower age must be strictly less than upper age.\n" + error_message = f"{name} has invalid interval with {len(invalid_index)} indices. 
\nLower bound must be strictly less than upper bound.\n" error_message += f"Index columns: ({', '.join(index)})\n" if len(invalid_index) > 5: error_message += "First 5 indices with invalid interval: \n" @@ -75,17 +165,34 @@ def validate_interval( def validate_noindexdiff( df_ref: DataFrame, df: DataFrame, index: list[str], name: str ) -> None: + """ + Validates that the indices of two DataFrames match. + + Parameters + ---------- + df_ref : pandas.DataFrame + The reference DataFrame containing the expected indices. + df : pandas.DataFrame + The DataFrame to validate against the reference. + index : list of str + A list of column names to be used as the index for comparison. + name : str + A name for the validation context, used in error messages. + + Raises + ------ + ValueError + If there are indices in the reference DataFrame that are missing in the DataFrame to validate. + """ index_ref = pd.MultiIndex.from_frame(df_ref[index]) - index = pd.MultiIndex.from_frame(df[index]) - missing_index = index_ref.difference(index).to_list() + index_to_check = pd.MultiIndex.from_frame(df[index]) + missing_index = index_ref.difference(index_to_check).to_list() if missing_index: - error_message = ( - f"Missing {name} info for {len(missing_index)} indices \n" - ) - error_message += f"Index columns: ({', '.join(index.names)})\n" + error_message = f"Missing {name} info for {len(missing_index)} indices \n" + error_message += f"Index columns: ({', '.join(index_ref.names)})\n" if len(missing_index) > 5: - error_message += "First 5: \n" + error_message += "First 5 missing indices: \n" error_message += ", \n".join(str(idx) for idx in missing_index[:5]) error_message += "\n" raise ValueError(error_message) @@ -100,16 +207,36 @@ def validate_pat_coverage( index: list[str], name: str, ) -> None: - """Validation checks for incomplete age pattern - * pattern age intervals do not overlap or have gaps - * smallest pattern interval doesn't cover the left end point of data - * largest pattern interval doesn't cover the right end point of data """ - # sort dataframe + Validates that the pattern intervals cover the data intervals completely without gaps or overlaps. + + Parameters + ---------- + df : pandas.DataFrame + The DataFrame containing both data intervals and pattern intervals. + lwr : str + The name of the column representing the data's lower bound. + upr : str + The name of the column representing the data's upper bound. + pat_lwr : str + The name of the column representing the pattern's lower bound. + pat_upr : str + The name of the column representing the pattern's upper bound. + index : list of str + A list of column names to group by when validating intervals. + name : str + A name for the DataFrame or validation context, used in error messages. + + Raises + ------ + ValueError + If the pattern intervals have gaps or overlaps, or if they do not fully cover the data intervals. + """ + # Sort dataframe df = df.sort_values(index + [lwr, upr, pat_lwr, pat_upr], ignore_index=True) df_group = df.groupby(index) - # check overlap or gap in pattern + # Check overlap or gap in pattern shifted_pat_upr = df_group[pat_upr].shift(1) connect_index = shifted_pat_upr.notnull() connected = np.allclose( @@ -122,7 +249,7 @@ def validate_pat_coverage( "bounds across categories." 
) - # check coverage of head and tail + # Check coverage of head and tail head_covered = df_group.first().eval(f"{lwr} >= {pat_lwr}").all() tail_covered = df_group.last().eval(f"{upr} <= {pat_upr}").all() @@ -134,31 +261,61 @@ def validate_pat_coverage( def validate_realnumber(df: DataFrame, columns: list[str], name: str) -> None: """ - Validates that observation values in columns are real numbers and non-zero. + Validates that specified columns contain real numbers, are non-zero, and are not NaN or Inf. Parameters ---------- - df : DataFrame + df : pandas.DataFrame The DataFrame containing the data to validate. columns : list of str A list of column names to validate within the DataFrame. name : str - A string representing the name of the data or dataset - (used for constructing error messages). + A name for the DataFrame, used in error messages. Raises ------ ValueError - If any column contains values that are not real numbers or are zero. + If any column contains values that are not real numbers, are zero, or are NaN/Inf. """ - # Check for non-real or zero values in the specified columns - invalid = [ - col - for col in columns - if not df[col] - .apply(lambda x: isinstance(x, (int, float)) and x != 0) - .all() - ] + # Check for non-real, zero, NaN, or Inf values in the specified columns + invalid = [] + for col in columns: + if ( + not df[col] + .apply( + lambda x: isinstance(x, (int, float)) + and x != 0 + and pd.notna(x) + and np.isfinite(x) + ) + .all() + ): + invalid.append(col) if invalid: raise ValueError(f"{name} has non-real or zero values in: {invalid}") + + +def validate_set_uniqueness(df: DataFrame, column: str, name: str) -> None: + """ + Validates that each list in the specified column contains unique elements. + + Parameters + ---------- + df : pandas.DataFrame + The DataFrame containing the data to validate. + column : str + The name of the column containing lists to validate. + name : str + A name for the DataFrame or validation context, used in error messages. + + Raises + ------ + ValueError + If any list in the specified column contains duplicate elements. 
+ """ + invalid_rows = df[df[column].apply(lambda x: len(x) != len(set(x)))] + if not invalid_rows.empty: + error_message = f"{name} has rows in column '{column}' where list elements are not unique.\n" + error_message += f"Indices of problematic rows: {invalid_rows.index.tolist()}\n" + raise ValueError(error_message) diff --git a/tests/test_cat_splitter.py b/tests/test_cat_splitter.py new file mode 100644 index 0000000..a2228f5 --- /dev/null +++ b/tests/test_cat_splitter.py @@ -0,0 +1,245 @@ +import pytest +import pandas as pd +from pydisagg.ihme.splitter import ( + CatSplitter, + CatDataConfig, + CatPatternConfig, + CatPopulationConfig, +) + +# Step 1: Setup Fixtures + + +@pytest.fixture +def cat_data_config(): + return CatDataConfig( + index=[ + "study_id", + "year_id", + "location_id", + "sub_category", + ], # Include 'sub_category' in index + cat_group="sub_category", + val="val", + val_sd="val_sd", + ) + + +@pytest.fixture +def cat_pattern_config(): + return CatPatternConfig( + by=["year_id", "location_id"], + cat="sub_category", + val="pattern_val", + val_sd="pattern_val_sd", + ) + + +@pytest.fixture +def cat_population_config(): + return CatPopulationConfig( + index=["year_id", "location_id", "sub_category"], + val="population", + ) + + +@pytest.fixture +def valid_data(): + return pd.DataFrame( + { + "study_id": [1, 2, 3], + "year_id": [2000, 2000, 2001], + "location_id": [10, 20, 10], + "sub_category": [ + ["A1", "A2"], # List of sub_categories + ["B1", "B2"], + ["C1", "C2"], + ], + "val": [100, 200, 150], + "val_sd": [10, 20, 15], + } + ) + + +@pytest.fixture +def valid_pattern(): + return pd.DataFrame( + { + "year_id": [2000, 2000, 2000, 2000, 2001, 2001], + "location_id": [10, 10, 20, 20, 10, 10], + "sub_category": ["A1", "A2", "B1", "B2", "C1", "C2"], + "pattern_val": [0.6, 0.4, 0.7, 0.3, 0.55, 0.45], + "pattern_val_sd": [0.06, 0.04, 0.07, 0.03, 0.055, 0.045], + } + ) + + +@pytest.fixture +def valid_population(): + return pd.DataFrame( + { + "year_id": [2000, 2000, 2000, 2000, 2001, 2001], + "location_id": [10, 10, 20, 20, 10, 10], + "sub_category": ["A1", "A2", "B1", "B2", "C1", "C2"], + "population": [5000, 3000, 7000, 3000, 5500, 4500], + } + ) + + +@pytest.fixture +def cat_splitter(cat_data_config, cat_pattern_config, cat_population_config): + return CatSplitter( + data=cat_data_config, + pattern=cat_pattern_config, + population=cat_population_config, + ) + + +# Step 2: Write Tests for parse_data + + +def test_parse_data_duplicated_index(cat_splitter, valid_data): + """Test parse_data raises an error on duplicated index.""" + duplicated_data = pd.concat([valid_data, valid_data]) + with pytest.raises(ValueError, match="has duplicated index"): + cat_splitter.parse_data(duplicated_data, positive_strict=True) + + +def test_parse_data_valid(cat_splitter, valid_data): + """Test that parse_data works correctly on valid data.""" + parsed_data = cat_splitter.parse_data(valid_data, positive_strict=True) + assert not parsed_data.empty + assert "val" in parsed_data.columns + assert "val_sd" in parsed_data.columns + + +# Step 3: Write Tests for parse_pattern + + +def test_parse_pattern_valid(cat_splitter, valid_data, valid_pattern): + """Test that parse_pattern works correctly on valid data.""" + parsed_data = cat_splitter.parse_data(valid_data, positive_strict=True) + parsed_pattern = cat_splitter.parse_pattern( + parsed_data, valid_pattern, model="rate" + ) + assert not parsed_pattern.empty + # The pattern columns are renamed with prefix 'cat_pat_' + assert ( + 
f"{cat_splitter.pattern.prefix}{cat_splitter.pattern.val}" + in parsed_pattern.columns + ) + assert ( + f"{cat_splitter.pattern.prefix}{cat_splitter.pattern.val_sd}" + in parsed_pattern.columns + ) + + +# Step 4: Write Tests for parse_population + + +def test_parse_population_missing_columns( + cat_splitter, valid_data, valid_pattern, valid_population +): + """Test parse_population raises an error when population columns are missing.""" + invalid_population = valid_population.drop(columns=["population"]) + parsed_data = cat_splitter.parse_data(valid_data, positive_strict=True) + parsed_pattern = cat_splitter.parse_pattern( + parsed_data, valid_pattern, model="rate" + ) + with pytest.raises(KeyError, match="has missing columns"): + cat_splitter.parse_population(parsed_pattern, invalid_population) + + +def test_parse_population_with_nan( + cat_splitter, valid_data, valid_pattern, valid_population +): + """Test parse_population raises an error when there are NaN values.""" + invalid_population = valid_population.copy() + invalid_population.loc[0, "population"] = None + parsed_data = cat_splitter.parse_data(valid_data, positive_strict=True) + parsed_pattern = cat_splitter.parse_pattern( + parsed_data, valid_pattern, model="rate" + ) + with pytest.raises(ValueError, match="has NaN values"): + cat_splitter.parse_population(parsed_pattern, invalid_population) + + +def test_parse_population_valid( + cat_splitter, valid_data, valid_pattern, valid_population +): + """Test that parse_population works correctly on valid data.""" + parsed_data = cat_splitter.parse_data(valid_data, positive_strict=True) + parsed_pattern = cat_splitter.parse_pattern( + parsed_data, valid_pattern, model="rate" + ) + parsed_population = cat_splitter.parse_population( + parsed_pattern, valid_population + ) + assert not parsed_population.empty + # The population column is renamed with prefix 'cat_pop_' + pop_col = f"{cat_splitter.population.prefix}{cat_splitter.population.val}" + assert pop_col in parsed_population.columns + + +# Step 5: Write Tests for the split method + + +def test_split_valid(cat_splitter, valid_data, valid_pattern, valid_population): + """Test that the split method works correctly on valid data.""" + result = cat_splitter.split( + data=valid_data, + pattern=valid_pattern, + population=valid_population, + model="rate", + output_type="rate", + ) + assert not result.empty + assert "split_result" in result.columns + assert "split_result_se" in result.columns + + +def test_split_with_invalid_output_type( + cat_splitter, valid_data, valid_pattern, valid_population +): + """Test that the split method raises an error with an invalid output_type.""" + with pytest.raises(ValueError, match="Invalid output_type"): + cat_splitter.split( + data=valid_data, + pattern=valid_pattern, + population=valid_population, + model="rate", + output_type="invalid_output", + ) + + +def test_split_with_missing_population(cat_splitter, valid_data, valid_pattern): + """Test that the split method raises an error when population data is missing.""" + with pytest.raises( + KeyError, match="Parsing Population has missing columns" + ): + cat_splitter.split( + data=valid_data, + pattern=valid_pattern, + population=pd.DataFrame(), # Empty population data + model="rate", + output_type="rate", + ) + + +def test_split_with_non_matching_categories( + cat_splitter, valid_data, valid_pattern, valid_population +): + """Test that the split method raises an error when categories don't match.""" + invalid_population = valid_population.copy() + 
invalid_population["sub_category"] = ["X1", "X2", "X1", "X2", "X1", "X2"] + with pytest.raises( + ValueError, + match="After merging with population, there were NaN values created", + ): + cat_splitter.split( + data=valid_data, + pattern=valid_pattern, + population=invalid_population, + model="rate", + output_type="rate", + ) diff --git a/tests/test_validator.py b/tests/test_validator.py index 7c528f7..6a3464d 100644 --- a/tests/test_validator.py +++ b/tests/test_validator.py @@ -10,9 +10,12 @@ validate_nonan, validate_pat_coverage, validate_positive, + validate_set_uniqueness, + validate_realnumber, ) +# Test functions @pytest.fixture def data(): np.random.seed(123) @@ -65,6 +68,7 @@ def population(): return population +# Tests for validate_columns def test_validate_columns_missing(population): with pytest.raises(KeyError): validate_columns( @@ -74,6 +78,16 @@ def test_validate_columns_missing(population): ) +def test_validate_columns_no_missing(population): + # All columns are present; should pass + validate_columns( + population, + ["sex_id", "location_id", "age_group_id", "year_id", "population"], + "population", + ) + + +# Tests for validate_index def test_validate_index_missing(population): with pytest.raises(ValueError): validate_index( @@ -83,11 +97,27 @@ def test_validate_index_missing(population): ) +def test_validate_index_no_duplicates(population): + # Ensure DataFrame has no duplicate indices; should pass + validate_index( + population, + ["sex_id", "location_id", "age_group_id", "year_id"], + "population", + ) + + +# Tests for validate_nonan def test_validate_nonan(population): with pytest.raises(ValueError): validate_nonan(population.assign(population=np.nan), "population") +def test_validate_nonan_no_nan(population): + # No NaN values; should pass + validate_nonan(population, "population") + + +# Tests for validate_positive def test_validate_positive_strict(population): with pytest.raises(ValueError): validate_positive( @@ -108,6 +138,12 @@ def test_validate_positive_not_strict(population): ) +def test_validate_positive_no_error(population): + validate_positive(population, ["population"], "population", strict=True) + validate_positive(population, ["population"], "population", strict=False) + + +# Tests for validate_interval def test_validate_interval_lower_equal_upper(data): with pytest.raises(ValueError): validate_interval( @@ -130,11 +166,7 @@ def test_validate_interval_positive(data): validate_interval(data, "age_start", "age_end", ["uid"], "data") -def test_validate_positive_no_error(population): - validate_positive(population, ["population"], "population", strict=True) - validate_positive(population, ["population"], "population", strict=False) - - +# Tests for validate_noindexdiff @pytest.fixture def merged_data_pattern(data, pattern): return pd.merge( @@ -149,7 +181,7 @@ def test_validate_noindexdiff_merged_positive(merged_data_pattern, population): # Positive test case: no index difference validate_noindexdiff( population, - merged_data_pattern, + merged_data_pattern.dropna(subset=["sex_id", "location_id"]), ["sex_id", "location_id"], "merged_data_pattern", ) @@ -173,6 +205,7 @@ def test_validate_noindexdiff_merged_negative(data, pattern): ) +# Tests for validate_pat_coverage @pytest.mark.parametrize( "bad_data_with_pattern", [ @@ -221,3 +254,81 @@ def test_validate_pat_coverage_failure(bad_data_with_pattern): ["group_id"], "pattern", ) + + +# Tests for validate_realnumber +def test_validate_realnumber_positive(): + df = pd.DataFrame({"col1": [1, 2.5, -3.5, 4.2], 
"col2": [5.1, 6, 7, 8]}) + # Should pass without exceptions + validate_realnumber(df, ["col1", "col2"], "df") + + +def test_validate_realnumber_zero(): + df = pd.DataFrame({"col1": [1, 2, 0, 4], "col2": [5, 6, 7, 8]}) + with pytest.raises( + ValueError, match="df has non-real or zero values in: \\['col1'\\]" + ): + validate_realnumber(df, ["col1"], "df") + + +def test_validate_realnumber_nan(): + df = pd.DataFrame({"col1": [1, 2, 3, np.nan], "col2": [5, 6, 7, 8]}) + with pytest.raises( + ValueError, match="df has non-real or zero values in: \\['col1'\\]" + ): + validate_realnumber(df, ["col1"], "df") + + +def test_validate_realnumber_non_numeric(): + df = pd.DataFrame({"col1": [1, 2, 3, "a"], "col2": [5, 6, 7, 8]}) + with pytest.raises( + ValueError, match="df has non-real or zero values in: \\['col1'\\]" + ): + validate_realnumber(df, ["col1"], "df") + + +def test_validate_realnumber_infinite(): + df = pd.DataFrame({"col1": [1, 2, 3, np.inf], "col2": [5, 6, 7, 8]}) + # np.inf is not a finite real number + with pytest.raises( + ValueError, match="df has non-real or zero values in: \\['col1'\\]" + ): + validate_realnumber(df, ["col1"], "df") + + +# Tests for validate_set_uniqueness +def test_validate_set_uniqueness_positive(): + df = pd.DataFrame( + {"col1": [[1, 2, 3], ["a", "b", "c"], [True, False], [1.1, 2.2, 3.3]]} + ) + # Should pass without exceptions + validate_set_uniqueness(df, "col1", "df") + + +def test_validate_set_uniqueness_negative(): + df = pd.DataFrame( + {"col1": [[1, 2, 2], ["a", "b", "a"], [True, False], [1.1, 2.2, 1.1]]} + ) + with pytest.raises( + ValueError, + match="df has rows in column 'col1' where list elements are not unique.", + ): + validate_set_uniqueness(df, "col1", "df") + + +def test_validate_set_uniqueness_empty_lists(): + df = pd.DataFrame({"col1": [[], [], []]}) + # Should pass; empty lists have no duplicates + validate_set_uniqueness(df, "col1", "df") + + +def test_validate_set_uniqueness_single_element_lists(): + df = pd.DataFrame({"col1": [[1], ["a"], [True]]}) + # Should pass; single-element lists can't have duplicates + validate_set_uniqueness(df, "col1", "df") + + +def test_validate_set_uniqueness_mixed_types_with_duplicates(): + df = pd.DataFrame({"col1": [[1, "1", 1.0], [True, 1, 1.0], [2, 2, 2]]}) + with pytest.raises(ValueError): + validate_set_uniqueness(df, "col1", "df")