diff --git a/examples/ihme_api/cat_sex_split_example.py b/examples/ihme_api/cat_sex_split_example.py
new file mode 100644
index 0000000..3629ae3
--- /dev/null
+++ b/examples/ihme_api/cat_sex_split_example.py
@@ -0,0 +1,174 @@
+import pandas as pd
+import numpy as np
+
+# Import CatSplitter and its configuration classes
+from pydisagg.ihme.splitter import (
+ CatSplitter,
+ CatDataConfig,
+ CatPatternConfig,
+ CatPopulationConfig,
+)
+
+# Set a random seed for reproducibility
+np.random.seed(42)
+
+# -------------------------------
+# 1. Create and Update data_df
+# -------------------------------
+
+# Existing data_df DataFrame
+data_df = pd.DataFrame(
+ {
+ "seq": [303284043, 303284062, 303284063, 303284064, 303284065],
+ "location_id": [78, 130, 120, 30, 141],
+ "mean": [0.5] * 5,
+ "standard_error": [0.1] * 5,
+ "year_id": [2015, 2019, 2018, 2017, 2016],
+ }
+)
+
+# Add a 'sex' column holding the list [1, 2]; each row is later split by sex
+data_df["sex"] = [[1, 2]] * len(data_df)
+
+# Sort data_df for clarity
+data_df_sorted = data_df.sort_values(by=["location_id"]).reset_index(drop=True)
+
+# Display the sorted data_df
+print("data_df:")
+print(data_df_sorted)
+
+# -------------------------------
+# 2. Create and Update pattern_df_final
+# -------------------------------
+
+pattern_df = pd.DataFrame(
+ {
+ "location_id": [78, 130, 120, 30, 141],
+ "mean": [0.5] * 5,
+ "standard_error": [0.1] * 5,
+ "year_id": [2015, 2019, 2018, 2017, 2016],
+ }
+)
+
+# Create DataFrame for sex=1
+pattern_df_sex1 = pattern_df.copy()
+pattern_df_sex1["sex"] = 1 # Assign sex=1
+pattern_df_sex1["mean"] += np.random.normal(0, 0.01, size=len(pattern_df_sex1))
+pattern_df_sex1["standard_error"] += np.random.normal(
+ 0, 0.001, size=len(pattern_df_sex1)
+)
+pattern_df_sex1["mean"] = pattern_df_sex1["mean"].round(6)
+pattern_df_sex1["standard_error"] = pattern_df_sex1["standard_error"].round(6)
+
+# Create DataFrame for sex=2
+pattern_df_sex2 = pattern_df.copy()
+pattern_df_sex2["sex"] = 2 # Assign sex=2
+pattern_df_sex2["mean"] += np.random.normal(0, 0.01, size=len(pattern_df_sex2))
+pattern_df_sex2["standard_error"] += np.random.normal(
+ 0, 0.001, size=len(pattern_df_sex2)
+)
+pattern_df_sex2["mean"] = pattern_df_sex2["mean"].round(6)
+pattern_df_sex2["standard_error"] = pattern_df_sex2["standard_error"].round(6)
+
+pattern_df_final = pd.concat(
+ [pattern_df_sex1, pattern_df_sex2], ignore_index=True
+)
+
+# Sort pattern_df_final for clarity
+pattern_df_final_sorted = pattern_df_final.sort_values(
+ by=["location_id", "sex"]
+).reset_index(drop=True)
+
+print("\npattern_df_final:")
+print(pattern_df_final_sorted)
+
+# -------------------------------
+# 3. Create and Update population_df
+# -------------------------------
+
+population_df = pd.DataFrame(
+ {
+ "location_id": [30, 30, 78, 78, 120, 120, 130, 130, 141, 141],
+ "year_id": [2017] * 2
+ + [2015] * 2
+ + [2018] * 2
+ + [2019] * 2
+ + [2016] * 2,
+ "sex": [1, 2] * 5, # Sexes 1 and 2
+ "population": [
+ 39789,
+ 40120,
+ 10234,
+ 10230,
+ 30245,
+ 29870,
+ 19876,
+ 19980,
+ 50234,
+ 49850,
+ ],
+ }
+)
+
+# Sort population_df for clarity
+population_df_sorted = population_df.sort_values(
+ by=["location_id", "sex"]
+).reset_index(drop=True)
+
+# Display the sorted population_df
+print("\npopulation_df:")
+print(population_df_sorted)
+
+# -------------------------------
+# 4. Configure and Run CatSplitter
+# -------------------------------
+
+# Data configuration
+data_config = CatDataConfig(
+ index=[
+ "seq",
+ "location_id",
+ "year_id",
+ "sex",
+ ], # Include 'sex' in the index
+ cat_group="sex",
+ val="mean",
+ val_sd="standard_error",
+)
+
+# Pattern configuration
+pattern_config = CatPatternConfig(
+ by=["location_id", "year_id"],
+ cat="sex",
+ val="mean",
+ val_sd="standard_error",
+)
+
+# Population configuration
+population_config = CatPopulationConfig(
+ index=["location_id", "year_id", "sex"], # Include 'sex' in the index
+ val="population",
+)
+
+# Initialize the CatSplitter
+splitter = CatSplitter(
+ data=data_config, pattern=pattern_config, population=population_config
+)
+
+# Perform the split
+try:
+ final_split_df = splitter.split(
+ data=data_df,
+ pattern=pattern_df_final,
+ population=population_df,
+ model="rate",
+ output_type="rate",
+ )
+
+ # Sort the final DataFrame by 'seq' and then by 'sex'
+ final_split_df.sort_values(by=["seq", "sex"], inplace=True)
+
+ print("\nFinal Split DataFrame:")
+ print(final_split_df)
+except Exception as e:
+ print(f"Error during splitting: {e}")
diff --git a/examples/ihme_api/cat_split.ipynb b/examples/ihme_api/cat_split.ipynb
new file mode 100644
index 0000000..e328e09
--- /dev/null
+++ b/examples/ihme_api/cat_split.ipynb
@@ -0,0 +1,401 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Pre-split DataFrame:\n",
+ " study_id year_id location_id mean std_err\n",
+ "0 8270 2010 [1234, 1235, 1236] 0.2 0.01\n",
+ "1 1860 2010 [2345, 2346, 2347] 0.3 0.02\n",
+ "2 6390 2010 [3456] 0.4 0.03\n",
+ "\n",
+ "Pattern DataFrame:\n",
+ " location_id year_id mean std_err\n",
+ "0 1234 2010 0.392798 0.048796\n",
+ "1 1235 2010 0.339463 0.043298\n",
+ "2 1236 2010 0.162407 0.018494\n",
+ "3 2345 2010 0.162398 0.017273\n",
+ "4 2346 2010 0.123233 0.017336\n",
+ "5 2347 2010 0.446470 0.022170\n",
+ "6 3456 2010 0.340446 0.030990\n",
+ "7 4567 2010 0.383229 0.027278\n",
+ "8 5678 2010 0.108234 0.021649\n",
+ "\n",
+ "Population DataFrame:\n",
+ " location_id year_id population\n",
+ "0 1234 2010 166730\n",
+ "1 1235 2010 880910\n",
+ "2 1236 2010 394681\n",
+ "3 2345 2010 159503\n",
+ "4 2346 2010 664811\n",
+ "5 2347 2010 537035\n",
+ "6 3456 2010 658143\n",
+ "7 4567 2010 462366\n",
+ "8 5678 2010 75725\n",
+ "\n",
+ "Final Split DataFrame:\n",
+ " mean study_id std_err location_id year_id cat_pat_mean \\\n",
+ "3 0.3 1860 0.02 2345 2010 0.162398 \n",
+ "4 0.3 1860 0.02 2346 2010 0.123233 \n",
+ "5 0.3 1860 0.02 2347 2010 0.446470 \n",
+ "6 0.4 6390 0.03 3456 2010 0.340446 \n",
+ "0 0.2 8270 0.01 1234 2010 0.392798 \n",
+ "1 0.2 8270 0.01 1235 2010 0.339463 \n",
+ "2 0.2 8270 0.01 1236 2010 0.162407 \n",
+ "\n",
+ " cat_pat_std_err population split_result split_result_se split_flag \\\n",
+ "3 0.017273 159503.0 0.190806 0.024440 1 \n",
+ "4 0.017336 664811.0 0.144790 0.019012 1 \n",
+ "5 0.022170 537035.0 0.524570 0.040101 1 \n",
+ "6 0.030990 658143.0 0.400000 0.030000 0 \n",
+ "0 0.048796 166730.0 0.264351 0.039018 1 \n",
+ "1 0.043298 880910.0 0.228457 0.015557 1 \n",
+ "2 0.018494 394681.0 0.109300 0.015518 1 \n",
+ "\n",
+ " orig_group \n",
+ "3 [2345, 2346, 2347] \n",
+ "4 [2345, 2346, 2347] \n",
+ "5 [2345, 2346, 2347] \n",
+ "6 [3456] \n",
+ "0 [1234, 1235, 1236] \n",
+ "1 [1234, 1235, 1236] \n",
+ "2 [1234, 1235, 1236] \n"
+ ]
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "\n",
+ "# Assuming the CatSplitter and configuration classes have been imported correctly\n",
+ "from pydisagg.ihme.splitter import (\n",
+ " CatSplitter,\n",
+ " CatDataConfig,\n",
+ " CatPatternConfig,\n",
+ " CatPopulationConfig,\n",
+ ")\n",
+ "\n",
+ "# Set a random seed for reproducibility\n",
+ "np.random.seed(42)\n",
+ "\n",
+ "# -------------------------------\n",
+ "# Example DataFrames\n",
+ "# -------------------------------\n",
+ "\n",
+ "# Pre-split DataFrame with 3 rows\n",
+ "pre_split = pd.DataFrame(\n",
+ " {\n",
+ " \"study_id\": np.random.randint(1000, 9999, size=3), # Unique study IDs\n",
+ " \"year_id\": [2010, 2010, 2010],\n",
+ " \"location_id\": [\n",
+ " [1234, 1235, 1236], # List of location_ids for row 1\n",
+ " [2345, 2346, 2347], # List of location_ids for row 2\n",
+ " [3456], # Single location_id for row 3 (no need to split)\n",
+ " ],\n",
+ " \"mean\": [0.2, 0.3, 0.4],\n",
+ " \"std_err\": [0.01, 0.02, 0.03],\n",
+ " }\n",
+ ")\n",
+ "\n",
+ "# Create a list of all location_ids mentioned\n",
+ "all_location_ids = [\n",
+ " 1234,\n",
+ " 1235,\n",
+ " 1236,\n",
+ " 2345,\n",
+ " 2346,\n",
+ " 2347,\n",
+ " 3456,\n",
+ " 4567, # Additional location_ids\n",
+ " 5678,\n",
+ "]\n",
+ "\n",
+ "# Pattern DataFrame for all location_ids\n",
+ "data_pattern = pd.DataFrame(\n",
+ " {\n",
+ " \"location_id\": all_location_ids,\n",
+ " \"year_id\": [2010] * len(all_location_ids),\n",
+ " \"mean\": np.random.uniform(0.1, 0.5, len(all_location_ids)),\n",
+ " \"std_err\": np.random.uniform(0.01, 0.05, len(all_location_ids)),\n",
+ " }\n",
+ ")\n",
+ "\n",
+ "# Population DataFrame for all location_ids\n",
+ "data_pop = pd.DataFrame(\n",
+ " {\n",
+ " \"location_id\": all_location_ids,\n",
+ " \"year_id\": [2010] * len(all_location_ids),\n",
+ " \"population\": np.random.randint(10000, 1000000, len(all_location_ids)),\n",
+ " }\n",
+ ")\n",
+ "\n",
+ "# Print the DataFrames\n",
+ "print(\"Pre-split DataFrame:\")\n",
+ "print(pre_split)\n",
+ "print(\"\\nPattern DataFrame:\")\n",
+ "print(data_pattern)\n",
+ "print(\"\\nPopulation DataFrame:\")\n",
+ "print(data_pop)\n",
+ "\n",
+ "# -------------------------------\n",
+ "# Configurations\n",
+ "# -------------------------------\n",
+ "\n",
+ "data_config = CatDataConfig(\n",
+ " index=[\"study_id\", \"year_id\"], # Include study_id in the index\n",
+ " target=\"location_id\", # Column containing list of targets\n",
+ " val=\"mean\",\n",
+ " val_sd=\"std_err\",\n",
+ ")\n",
+ "\n",
+ "pattern_config = CatPatternConfig(\n",
+ " index=[\"year_id\"],\n",
+ " target=\"location_id\",\n",
+ " val=\"mean\",\n",
+ " val_sd=\"std_err\",\n",
+ ")\n",
+ "\n",
+ "population_config = CatPopulationConfig(\n",
+ " index=[\"year_id\"],\n",
+ " target=\"location_id\",\n",
+ " val=\"population\",\n",
+ ")\n",
+ "\n",
+ "# Initialize the CatSplitter\n",
+ "splitter = CatSplitter(\n",
+ " data=data_config, pattern=pattern_config, population=population_config\n",
+ ")\n",
+ "\n",
+ "# Perform the split\n",
+ "try:\n",
+ " final_split_df = splitter.split(\n",
+ " data=pre_split,\n",
+ " pattern=data_pattern,\n",
+ " population=data_pop,\n",
+ " model=\"rate\",\n",
+ " output_type=\"rate\",\n",
+ " )\n",
+ " final_split_df.sort_values(by=[\"study_id\", \"location_id\"], inplace=True)\n",
+ " print(\"\\nFinal Split DataFrame:\")\n",
+ " print(final_split_df)\n",
+ "except ValueError as e:\n",
+ " print(f\"Error: {e}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " mean | \n",
+ " study_id | \n",
+ " std_err | \n",
+ " location_id | \n",
+ " year_id | \n",
+ " cat_pat_mean | \n",
+ " cat_pat_std_err | \n",
+ " population | \n",
+ " split_result | \n",
+ " split_result_se | \n",
+ " split_flag | \n",
+ " orig_group | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 3 | \n",
+ " 0.3 | \n",
+ " 1860 | \n",
+ " 0.02 | \n",
+ " 2345 | \n",
+ " 2010 | \n",
+ " 0.162398 | \n",
+ " 0.017273 | \n",
+ " 159503.0 | \n",
+ " 0.190806 | \n",
+ " 0.024440 | \n",
+ " 1 | \n",
+ " [2345, 2346, 2347] | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.3 | \n",
+ " 1860 | \n",
+ " 0.02 | \n",
+ " 2346 | \n",
+ " 2010 | \n",
+ " 0.123233 | \n",
+ " 0.017336 | \n",
+ " 664811.0 | \n",
+ " 0.144790 | \n",
+ " 0.019012 | \n",
+ " 1 | \n",
+ " [2345, 2346, 2347] | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.3 | \n",
+ " 1860 | \n",
+ " 0.02 | \n",
+ " 2347 | \n",
+ " 2010 | \n",
+ " 0.446470 | \n",
+ " 0.022170 | \n",
+ " 537035.0 | \n",
+ " 0.524570 | \n",
+ " 0.040101 | \n",
+ " 1 | \n",
+ " [2345, 2346, 2347] | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.4 | \n",
+ " 6390 | \n",
+ " 0.03 | \n",
+ " 3456 | \n",
+ " 2010 | \n",
+ " 0.340446 | \n",
+ " 0.030990 | \n",
+ " 658143.0 | \n",
+ " 0.400000 | \n",
+ " 0.030000 | \n",
+ " 0 | \n",
+ " [3456] | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " 0.2 | \n",
+ " 8270 | \n",
+ " 0.01 | \n",
+ " 1234 | \n",
+ " 2010 | \n",
+ " 0.392798 | \n",
+ " 0.048796 | \n",
+ " 166730.0 | \n",
+ " 0.264351 | \n",
+ " 0.039018 | \n",
+ " 1 | \n",
+ " [1234, 1235, 1236] | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.2 | \n",
+ " 8270 | \n",
+ " 0.01 | \n",
+ " 1235 | \n",
+ " 2010 | \n",
+ " 0.339463 | \n",
+ " 0.043298 | \n",
+ " 880910.0 | \n",
+ " 0.228457 | \n",
+ " 0.015557 | \n",
+ " 1 | \n",
+ " [1234, 1235, 1236] | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.2 | \n",
+ " 8270 | \n",
+ " 0.01 | \n",
+ " 1236 | \n",
+ " 2010 | \n",
+ " 0.162407 | \n",
+ " 0.018494 | \n",
+ " 394681.0 | \n",
+ " 0.109300 | \n",
+ " 0.015518 | \n",
+ " 1 | \n",
+ " [1234, 1235, 1236] | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " mean study_id std_err location_id year_id cat_pat_mean \\\n",
+ "3 0.3 1860 0.02 2345 2010 0.162398 \n",
+ "4 0.3 1860 0.02 2346 2010 0.123233 \n",
+ "5 0.3 1860 0.02 2347 2010 0.446470 \n",
+ "6 0.4 6390 0.03 3456 2010 0.340446 \n",
+ "0 0.2 8270 0.01 1234 2010 0.392798 \n",
+ "1 0.2 8270 0.01 1235 2010 0.339463 \n",
+ "2 0.2 8270 0.01 1236 2010 0.162407 \n",
+ "\n",
+ " cat_pat_std_err population split_result split_result_se split_flag \\\n",
+ "3 0.017273 159503.0 0.190806 0.024440 1 \n",
+ "4 0.017336 664811.0 0.144790 0.019012 1 \n",
+ "5 0.022170 537035.0 0.524570 0.040101 1 \n",
+ "6 0.030990 658143.0 0.400000 0.030000 0 \n",
+ "0 0.048796 166730.0 0.264351 0.039018 1 \n",
+ "1 0.043298 880910.0 0.228457 0.015557 1 \n",
+ "2 0.018494 394681.0 0.109300 0.015518 1 \n",
+ "\n",
+ " orig_group \n",
+ "3 [2345, 2346, 2347] \n",
+ "4 [2345, 2346, 2347] \n",
+ "5 [2345, 2346, 2347] \n",
+ "6 [3456] \n",
+ "0 [1234, 1235, 1236] \n",
+ "1 [1234, 1235, 1236] \n",
+ "2 [1234, 1235, 1236] "
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "final_split_df"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "pyDis-mac",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/ihme_api/cat_split_example.py b/examples/ihme_api/cat_split_example.py
new file mode 100644
index 0000000..94dbb6e
--- /dev/null
+++ b/examples/ihme_api/cat_split_example.py
@@ -0,0 +1,126 @@
+# cat_split_example.py
+
+import numpy as np
+import pandas as pd
+from pandas import DataFrame
+
+# Import CatSplitter and related classes
+from pydisagg.ihme.splitter import (
+ CatSplitter,
+ CatDataConfig,
+ CatPatternConfig,
+ CatPopulationConfig,
+)
+
+# -------------------------------
+# Example DataFrames
+# -------------------------------
+
+# Set a random seed for reproducibility
+np.random.seed(42)
+
+# Pre-split DataFrame with 3 rows
+pre_split = pd.DataFrame(
+ {
+ "study_id": np.random.randint(1000, 9999, size=3), # Unique study IDs
+ "year_id": [2010, 2010, 2010],
+ "location_id": [
+ [1234, 1235, 1236], # List of location_ids for row 1
+ [2345, 2346, 2347], # List of location_ids for row 2
+ [3456], # Single location_id for row 3 (no need to split)
+ ],
+ "mean": [0.2, 0.3, 0.4],
+ "std_err": [0.01, 0.02, 0.03],
+ }
+)
+
+# Create a list of all location_ids mentioned
+all_location_ids = [
+ 1234,
+ 1235,
+ 1236,
+ 2345,
+ 2346,
+ 2347,
+ 3456,
+ 4567,
+ 5678, # Additional location_ids
+]
+
+# Pattern DataFrame for all location_ids
+data_pattern = pd.DataFrame(
+ {
+ "year_id": [2010] * len(all_location_ids),
+ "location_id": all_location_ids,
+ "mean": np.random.uniform(0.1, 0.5, len(all_location_ids)),
+ "std_err": np.random.uniform(0.01, 0.05, len(all_location_ids)),
+ }
+)
+
+# Population DataFrame for all location_ids
+data_pop = pd.DataFrame(
+ {
+ "year_id": [2010] * len(all_location_ids),
+ "location_id": all_location_ids,
+ "population": np.random.randint(10000, 1000000, len(all_location_ids)),
+ }
+)
+
+# Print the DataFrames
+print("Pre-split DataFrame:")
+print(pre_split)
+print("\nPattern DataFrame:")
+print(data_pattern)
+print("\nPopulation DataFrame:")
+print(data_pop)
+
+# -------------------------------
+# Configurations
+# -------------------------------
+
+# Configurations for the categorical split
+data_config = CatDataConfig(
+ index=[
+ "study_id",
+ "year_id",
+ "location_id",
+ ], # Include 'location_id' in the index
+ cat_group="location_id",
+ val="mean",
+ val_sd="std_err",
+)
+
+pattern_config = CatPatternConfig(
+ by=["year_id"],
+ cat="location_id",
+ val="mean",
+ val_sd="std_err",
+)
+
+population_config = CatPopulationConfig(
+ index=["year_id", "location_id"],
+ val="population",
+)
+
+# Initialize the CatSplitter with the updated configurations
+splitter = CatSplitter(
+ data=data_config,
+ pattern=pattern_config,
+ population=population_config,
+)
+
+# Perform the split
+try:
+ final_split_df = splitter.split(
+ data=pre_split,
+ pattern=data_pattern,
+ population=data_pop,
+ model="rate",
+ output_type="rate",
+ )
+ # Sort the final DataFrame for better readability
+ final_split_df.sort_values(by=["study_id", "location_id"], inplace=True)
+ print("\nFinal Split DataFrame:")
+ print(final_split_df)
+except Exception as e:
+ print(f"Error during splitting: {e}")
diff --git a/pyproject.toml b/pyproject.toml
index bc768b0..90dec5c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ requires = [
[project]
name = "pydisagg"
-version = "0.5.2"
+version = "0.6.0"
description = ""
readme = "README.md"
license = { text = "BSD 2-Clause License" }
diff --git a/src/pydisagg/ihme/splitter/__init__.py b/src/pydisagg/ihme/splitter/__init__.py
index 2687c09..c954cdf 100644
--- a/src/pydisagg/ihme/splitter/__init__.py
+++ b/src/pydisagg/ihme/splitter/__init__.py
@@ -10,6 +10,12 @@
SexPatternConfig,
SexPopulationConfig,
)
+from .cat_splitter import (
+ CatSplitter,
+ CatDataConfig,
+ CatPatternConfig,
+ CatPopulationConfig,
+)
__all__ = [
"AgeSplitter",
@@ -20,4 +26,8 @@
"SexDataConfig",
"SexPatternConfig",
"SexPopulationConfig",
+ "CatSplitter",
+ "CatDataConfig",
+ "CatPatternConfig",
+ "CatPopulationConfig",
]
diff --git a/src/pydisagg/ihme/splitter/age_splitter.py b/src/pydisagg/ihme/splitter/age_splitter.py
index bd389dc..e0f255e 100644
--- a/src/pydisagg/ihme/splitter/age_splitter.py
+++ b/src/pydisagg/ihme/splitter/age_splitter.py
@@ -197,7 +197,7 @@ def parse_pattern(
def _merge_with_pattern(
self, data: DataFrame, pattern: DataFrame
) -> DataFrame:
- # Ensure the necessary columns are present before merging
+ # TODO change these asserts to validate_columns
assert (
self.data.age_lwr in data.columns
), f"Column '{self.data.age_lwr}' not found in data"
@@ -236,11 +236,10 @@ def parse_population(
validate_index(population, self.population.index, name)
validate_nonan(population, name)
- pop_copy = population.copy()
rename_map = self.population.apply_prefix()
- pop_copy.rename(columns=rename_map, inplace=True)
+ population.rename(columns=rename_map, inplace=True)
- data_with_population = self._merge_with_population(data, pop_copy)
+ data_with_population = self._merge_with_population(data, population)
validate_noindexdiff(
data,
diff --git a/src/pydisagg/ihme/splitter/cat_splitter.py b/src/pydisagg/ihme/splitter/cat_splitter.py
new file mode 100644
index 0000000..2ee7545
--- /dev/null
+++ b/src/pydisagg/ihme/splitter/cat_splitter.py
@@ -0,0 +1,330 @@
+# cat_splitter.py
+
+from typing import Any, List, Literal
+import numpy as np
+import pandas as pd
+from pandas import DataFrame
+from pydantic import BaseModel
+
+from pydisagg.disaggregate import split_datapoint
+from pydisagg.models import RateMultiplicativeModel, LogOddsModel
+from pydisagg.ihme.schema import Schema
+from pydisagg.ihme.validator import (
+ validate_columns,
+ validate_index,
+ validate_noindexdiff,
+ validate_nonan,
+ validate_positive,
+ validate_set_uniqueness,
+)
+
+
+class CatDataConfig(Schema):
+ """
+ Configuration schema for categorical data DataFrame.
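+
+    Example
+    -------
+    An illustrative configuration; the column names are taken from the
+    bundled examples and are not required names::
+
+        CatDataConfig(
+            index=["study_id", "year_id", "location_id"],
+            cat_group="location_id",
+            val="mean",
+            val_sd="std_err",
+        )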
+ """
+
+ index: List[str]
+ cat_group: str
+ val: str
+ val_sd: str
+
+ @property
+ def columns(self) -> List[str]:
+ return self.index + [self.cat_group, self.val, self.val_sd]
+
+ @property
+ def val_fields(self) -> List[str]:
+ return [self.val, self.val_sd]
+
+
+class CatPatternConfig(Schema):
+ """
+ Configuration schema for the pattern DataFrame.
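+
+    Example
+    -------
+    An illustrative configuration matching the data config above; the column
+    names are placeholders::
+
+        CatPatternConfig(
+            by=["year_id"],
+            cat="location_id",
+            val="mean",
+            val_sd="std_err",
+        )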
+ """
+
+ by: List[str]
+ cat: str
+ draws: List[str] = []
+ val: str = "mean"
+ val_sd: str = "std_err"
+ prefix: str = "cat_pat_"
+
+ @property
+ def index(self) -> List[str]:
+ return self.by + [self.cat]
+
+ @property
+ def columns(self) -> List[str]:
+ return self.index + self.val_fields + self.draws
+
+ @property
+ def val_fields(self) -> List[str]:
+ return [self.val, self.val_sd]
+
+
+class CatPopulationConfig(Schema):
+ """
+ Configuration for the population DataFrame.
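+
+    Example
+    -------
+    An illustrative configuration; the column names are placeholders::
+
+        CatPopulationConfig(
+            index=["year_id", "location_id"],
+            val="population",
+        )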
+ """
+
+ index: List[str]
+ val: str
+ prefix: str = "cat_pop_"
+
+ @property
+ def columns(self) -> List[str]:
+ return self.index + [self.val]
+
+ @property
+ def val_fields(self) -> List[str]:
+ return [self.val]
+
+
+class CatSplitter(BaseModel):
+ """
+ Class for splitting categorical data based on pattern and population data.
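+
+    A typical flow (see examples/ihme_api/cat_split_example.py): build the
+    three configs, construct a CatSplitter, and call ``split`` with the data,
+    pattern, and population frames.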
+ """
+
+ data: CatDataConfig
+ pattern: CatPatternConfig
+ population: CatPopulationConfig
+
+ def model_post_init(self, __context: Any) -> None:
+ """
+ Perform extra validation after model initialization.
+ """
+ if not set(self.pattern.index).issubset(self.data.index):
+            raise ValueError(
+                "The pattern's index columns must be a subset of the data index columns."
+            )
+ if not set(self.population.index).issubset(
+ self.data.index + self.pattern.index
+ ):
+            raise ValueError(
+                "The population's index columns must be a subset of the data and pattern index columns."
+            )
+ if self.pattern.cat not in self.population.index:
+ raise ValueError(
+ "The 'target' column in the population must match the 'target' column in the data."
+ )
+
+ def parse_data(self, data: DataFrame, positive_strict: bool) -> DataFrame:
+ """
+ Parse and validate the input data DataFrame.
+ """
+ name = "While parsing data"
+
+ # Validate required columns
+ validate_columns(data, self.data.columns, name)
+
+        # Work on a copy so the caller's DataFrame is not mutated
+        data = data.copy()
+
+        # Ensure that the 'cat_group' column contains lists
+        data[self.data.cat_group] = data[self.data.cat_group].apply(
+            lambda x: x if isinstance(x, list) else [x]
+        )
+
+ # Validate that every list in 'cat_group' contains unique elements
+ validate_set_uniqueness(data, self.data.cat_group, name)
+
+ # Explode the 'cat_group' column and rename it to match the pattern's 'cat'
+ data = data.explode(self.data.cat_group).rename(
+ columns={self.data.cat_group: self.pattern.cat}
+ )
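+        # For example, a row with cat_group=[1, 2] becomes two rows, one per
+        # category, each carrying the same val and val_sd.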
+
+ # Validate index after exploding
+ validate_index(data, self.data.index, name)
+ validate_nonan(data, name)
+ validate_positive(data, [self.data.val_sd], name, strict=positive_strict)
+
+ return data
+
+ def _merge_with_pattern(
+ self,
+ data: DataFrame,
+ pattern: DataFrame,
+ ) -> DataFrame:
+ """
+ Merge data with pattern DataFrame.
+ """
+ data_with_pattern = data.merge(pattern, on=self.pattern.index, how="left")
+
+ validate_nonan(
+ data_with_pattern[
+ [f"{self.pattern.prefix}{col}" for col in self.pattern.val_fields]
+ ],
+ "After merging with pattern, there were NaN values created. This indicates that your pattern does not cover all the data.",
+ )
+
+ return data_with_pattern
+
+ def parse_pattern(
+ self, data: DataFrame, pattern: DataFrame, model: str
+ ) -> DataFrame:
+ """
+ Parse and merge the pattern DataFrame with data.
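+
+        If the pattern lacks the ``val``/``val_sd`` columns, they are derived
+        as the row-wise mean and standard deviation of the ``draws`` columns.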
+ """
+ name = "While parsing pattern"
+
+ try:
+ val_cols = self.pattern.val_fields
+ if not all(col in pattern.columns for col in val_cols):
+ if not self.pattern.draws:
+ raise ValueError(
+ f"{name}: Must provide draws for pattern if pattern.val and "
+ "pattern.val_sd are not available."
+ )
+ validate_columns(pattern, self.pattern.draws, name)
+ pattern[self.pattern.val] = pattern[self.pattern.draws].mean(axis=1)
+ pattern[self.pattern.val_sd] = pattern[self.pattern.draws].std(axis=1)
+
+ validate_columns(pattern, self.pattern.columns, name)
+ except KeyError as e:
+ raise KeyError(f"{name}: Missing columns in the pattern. Details:\n{e}")
+
+ pattern_copy = pattern.copy()
+ pattern_copy = pattern_copy[self.pattern.index + self.pattern.val_fields]
+ rename_map = {
+ col: f"{self.pattern.prefix}{col}" for col in self.pattern.val_fields
+ }
+ pattern_copy.rename(columns=rename_map, inplace=True)
+
+ # Merge with pattern
+ data_with_pattern = self._merge_with_pattern(data, pattern_copy)
+
+ # Validate index differences after merging
+ validate_noindexdiff(
+ data,
+ data_with_pattern,
+ self.data.index,
+ name,
+ )
+
+ return data_with_pattern
+
+ def parse_population(self, data: DataFrame, population: DataFrame) -> DataFrame:
+ name = "Parsing Population"
+ validate_columns(population, self.population.columns, name)
+
+ population = population[self.population.columns].copy()
+
+ validate_index(population, self.population.index, name)
+ validate_nonan(population, name)
+
+ rename_map = {
+ self.population.val: f"{self.population.prefix}{self.population.val}"
+ }
+ population.rename(columns=rename_map, inplace=True)
+
+ data_with_population = self._merge_with_population(data, population)
+
+ # Ensure the prefixed population column exists
+ pop_col = f"{self.population.prefix}{self.population.val}"
+ if pop_col not in data_with_population.columns:
+ raise KeyError(f"Expected column '{pop_col}' not found in merged data.")
+
+ validate_nonan(
+ data_with_population[[pop_col]],
+ "After merging with population, there were NaN values created. This indicates that your population data does not cover all the data.",
+ )
+ return data_with_population
+
+ def _merge_with_population(
+ self, data: DataFrame, population: DataFrame
+ ) -> DataFrame:
+ """
+ Merge data with population DataFrame.
+ """
+ data_with_population = data.merge(
+ population, on=self.population.index, how="left"
+ )
+
+ return data_with_population
+
+ def _process_group(
+ self, group: DataFrame, model: str, output_type: str
+ ) -> DataFrame:
+ """
+ Process a group of data for splitting.
+ """
+ observed_total = group[self.data.val].iloc[0]
+ observed_total_se = group[self.data.val_sd].iloc[0]
+
+ if len(group) == 1:
+ # No need to split, assign the observed values
+ group["split_result"] = observed_total
+ group["split_result_se"] = observed_total_se
+ group["split_flag"] = 0 # Not split
+ else:
+ # Need to split among multiple targets
+ bucket_populations = group[
+ f"{self.population.prefix}{self.population.val}"
+ ].values
+ rate_pattern = group[f"{self.pattern.prefix}{self.pattern.val}"].values
+ pattern_sd = group[f"{self.pattern.prefix}{self.pattern.val_sd}"].values
+ pattern_covariance = np.diag(pattern_sd**2)
+
+ if model == "rate":
+ splitting_model = RateMultiplicativeModel()
+ elif model == "logodds":
+ splitting_model = LogOddsModel()
+
+ # Determine whether to normalize by population for the output type
+ pop_normalize = output_type == "rate"
+
+ # Perform splitting
+ split_result, split_se = split_datapoint(
+ observed_total=observed_total,
+ bucket_populations=bucket_populations,
+ rate_pattern=rate_pattern,
+ model=splitting_model,
+ output_type=output_type,
+ normalize_pop_for_average_type_obs=pop_normalize,
+ observed_total_se=observed_total_se,
+ pattern_covariance=pattern_covariance,
+ )
+
+ # Assign results back to the group
+ group["split_result"] = split_result
+ group["split_result_se"] = split_se
+ group["split_flag"] = 1 # Split
+
+ return group
+
+ def split(
+ self,
+ data: DataFrame,
+ pattern: DataFrame,
+ population: DataFrame,
+ model: Literal["rate", "logodds"] = "rate",
+ output_type: Literal["rate", "count"] = "rate",
+ ) -> DataFrame:
+ """
+ Split the input data based on a specified pattern and population model.
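+
+        Returns one row per category with ``split_result``,
+        ``split_result_se``, and ``split_flag`` columns appended.
+
+        Example (illustrative; see examples/ihme_api/cat_split_example.py)::
+
+            result = splitter.split(
+                data=pre_split,
+                pattern=data_pattern,
+                population=data_pop,
+                model="rate",
+                output_type="rate",
+            )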
+ """
+ # Validate model and output_type
+ if model not in ["rate", "logodds"]:
+ raise ValueError(f"Invalid model: {model}")
+ if output_type not in ["rate", "count"]:
+ raise ValueError(f"Invalid output_type: {output_type}")
+
+ if self.population.prefix_status == "prefixed":
+ self.population.remove_prefix()
+ if self.pattern.prefix_status == "prefixed":
+ self.pattern.remove_prefix()
+
+ # Parsing input data, pattern, and population
+ data = self.parse_data(data, positive_strict=True)
+ data = self.parse_pattern(data, pattern, model)
+ data = self.parse_population(data, population)
+
+ # Determine grouping columns
+        group_cols = self.data.index[:-1]  # Exclude the category column
+
+ # Process groups using regular groupby
+ final_split_df = (
+ data.groupby(group_cols, group_keys=False)
+ .apply(lambda group: self._process_group(group, model, output_type))
+ .reset_index(drop=True)
+ )
+
+ return final_split_df
diff --git a/src/pydisagg/ihme/splitter/sex_splitter.py b/src/pydisagg/ihme/splitter/sex_splitter.py
index 4ebd359..b5008ee 100644
--- a/src/pydisagg/ihme/splitter/sex_splitter.py
+++ b/src/pydisagg/ihme/splitter/sex_splitter.py
@@ -89,8 +89,12 @@ def model_post_init(self, __context: Any) -> None:
"population.index must be a subset of data.index + pattern.index"
)
- def _merge_with_pattern(self, data: DataFrame, pattern: DataFrame) -> DataFrame:
- data_with_pattern = data.merge(pattern, on=self.pattern.by, how="left").dropna()
+ def _merge_with_pattern(
+ self, data: DataFrame, pattern: DataFrame
+ ) -> DataFrame:
+ data_with_pattern = data.merge(
+ pattern, on=self.pattern.by, how="left"
+ ).dropna()
return data_with_pattern
def get_population_by_sex(self, population, sex_value):
@@ -105,7 +109,9 @@ def parse_data(self, data: DataFrame) -> DataFrame:
try:
validate_columns(data, self.data.columns, name)
except KeyError as e:
- raise KeyError(f"{name}: Missing columns in the input data. Details:\n{e}")
+ raise KeyError(
+ f"{name}: Missing columns in the input data. Details:\n{e}"
+ )
if self.population.sex not in data.columns:
raise KeyError(
@@ -147,12 +153,18 @@ def parse_pattern(
"pattern.val_sd are not available."
)
validate_columns(pattern, self.pattern.draws, name)
- pattern[self.pattern.val] = pattern[self.pattern.draws].mean(axis=1)
- pattern[self.pattern.val_sd] = pattern[self.pattern.draws].std(axis=1)
+ pattern[self.pattern.val] = pattern[self.pattern.draws].mean(
+ axis=1
+ )
+ pattern[self.pattern.val_sd] = pattern[self.pattern.draws].std(
+ axis=1
+ )
validate_columns(pattern, self.pattern.columns, name)
except KeyError as e:
- raise KeyError(f"{name}: Missing columns in the pattern. Details:\n{e}")
+ raise KeyError(
+ f"{name}: Missing columns in the pattern. Details:\n{e}"
+ )
pattern = pattern[self.pattern.columns].copy()
@@ -166,7 +178,9 @@ def parse_pattern(
try:
validate_nonan(pattern, name)
except ValueError as e:
- raise ValueError(f"{name}: NaN values found in the pattern. Details:\n{e}")
+ raise ValueError(
+ f"{name}: NaN values found in the pattern. Details:\n{e}"
+ )
if model == "rate":
try:
@@ -196,7 +210,9 @@ def parse_pattern(
return data_with_pattern
- def parse_population(self, data: DataFrame, population: DataFrame) -> DataFrame:
+ def parse_population(
+ self, data: DataFrame, population: DataFrame
+ ) -> DataFrame:
name = "While parsing population"
# Step 1: Validate population columns
@@ -208,13 +224,19 @@ def parse_population(self, data: DataFrame, population: DataFrame) -> DataFrame:
)
# Step 2: Get male and female populations and rename columns
- male_population = self.get_population_by_sex(population, self.population.sex_m)
+ male_population = self.get_population_by_sex(
+ population, self.population.sex_m
+ )
female_population = self.get_population_by_sex(
population, self.population.sex_f
)
- male_population.rename(columns={self.population.val: "m_pop"}, inplace=True)
- female_population.rename(columns={self.population.val: "f_pop"}, inplace=True)
+ male_population.rename(
+ columns={self.population.val: "m_pop"}, inplace=True
+ )
+ female_population.rename(
+ columns={self.population.val: "f_pop"}, inplace=True
+ )
# Step 3: Merge population data with main data
data_with_population = self._merge_with_population(
@@ -242,7 +264,9 @@ def parse_population(self, data: DataFrame, population: DataFrame) -> DataFrame:
# Step 6: Validate index differences
try:
- validate_noindexdiff(data, data_with_population, self.data.index, name)
+ validate_noindexdiff(
+ data, data_with_population, self.data.index, name
+ )
except ValueError as e:
raise ValueError(
f"{name}: Index differences found between data and population. Details:\n{e}"
@@ -250,8 +274,12 @@ def parse_population(self, data: DataFrame, population: DataFrame) -> DataFrame:
# Ensure the columns are in the correct numeric type (e.g., float64)
# Convert "m_pop" and "f_pop" columns to standard numeric types if necessary
- data_with_population["m_pop"] = data_with_population["m_pop"].astype("float64")
- data_with_population["f_pop"] = data_with_population["f_pop"].astype("float64")
+ data_with_population["m_pop"] = data_with_population["m_pop"].astype(
+ "float64"
+ )
+ data_with_population["f_pop"] = data_with_population["f_pop"].astype(
+ "float64"
+ )
return data_with_population
@@ -268,9 +296,9 @@ def _merge_with_population(
# Ensure the merged population columns are standard numeric types
if pop_col in data_with_population.columns:
- data_with_population[pop_col] = data_with_population[pop_col].astype(
- "float64"
- )
+ data_with_population[pop_col] = data_with_population[
+ pop_col
+ ].astype("float64")
return data_with_population
@@ -376,12 +404,16 @@ def split(
lambda row: split_datapoint(
observed_total=row[self.data.val],
bucket_populations=np.array([row["m_pop"], row["f_pop"]]),
- rate_pattern=input_patterns[split_data.index.get_loc(row.name)],
+ rate_pattern=input_patterns[
+ split_data.index.get_loc(row.name)
+ ],
model=splitting_model,
output_type=output_type,
normalize_pop_for_average_type_obs=pop_normalize,
observed_total_se=row[self.data.val_sd],
- pattern_covariance=np.diag([0, row[self.pattern.val_sd] ** 2]),
+ pattern_covariance=np.diag(
+ [0, row[self.pattern.val_sd] ** 2]
+ ),
),
axis=1,
)
@@ -412,7 +444,11 @@ def split(
# Reindex columns
final_split_df = final_split_df.reindex(
columns=self.data.index
- + [col for col in final_split_df.columns if col not in self.data.index]
+ + [
+ col
+ for col in final_split_df.columns
+ if col not in self.data.index
+ ]
)
# Clean up any prefixes added earlier
diff --git a/src/pydisagg/ihme/validator.py b/src/pydisagg/ihme/validator.py
index 1c86379..550b253 100644
--- a/src/pydisagg/ihme/validator.py
+++ b/src/pydisagg/ihme/validator.py
@@ -4,6 +4,23 @@
def validate_columns(df: DataFrame, columns: list[str], name: str) -> None:
+ """
+ Validates that all specified columns are present in the DataFrame.
+
+ Parameters
+ ----------
+ df : pandas.DataFrame
+ The DataFrame to validate.
+ columns : list of str
+ A list of expected column names that should be present in the DataFrame.
+ name : str
+ A name for the DataFrame, used in error messages.
+
+ Raises
+ ------
+ KeyError
+ If any of the specified columns are missing from the DataFrame.
+ """
missing = [col for col in columns if col not in df.columns]
if missing:
error_message = (
@@ -18,11 +35,30 @@ def validate_columns(df: DataFrame, columns: list[str], name: str) -> None:
def validate_index(df: DataFrame, index: list[str], name: str) -> None:
+ """
+ Validates that the DataFrame does not contain duplicate indices based on specified columns.
+
+ Parameters
+ ----------
+ df : pandas.DataFrame
+ The DataFrame to validate.
+ index : list of str
+ A list of column names to be used as the index for validation.
+ name : str
+ A name for the DataFrame, used in error messages.
+
+ Raises
+ ------
+ ValueError
+ If duplicate indices are found in the DataFrame based on the specified columns.
+ """
duplicated_index = pd.MultiIndex.from_frame(
df[df[index].duplicated()][index]
).to_list()
if duplicated_index:
- error_message = f"{name} has duplicated index with {len(duplicated_index)} indices \n"
+ error_message = (
+ f"{name} has duplicated index with {len(duplicated_index)} indices \n"
+ )
error_message += f"Index columns: ({', '.join(index)})\n"
if len(duplicated_index) > 5:
error_message += "First 5: \n"
@@ -32,11 +68,24 @@ def validate_index(df: DataFrame, index: list[str], name: str) -> None:
def validate_nonan(df: DataFrame, name: str) -> None:
+ """
+ Validates that the DataFrame does not contain any NaN values.
+
+ Parameters
+ ----------
+ df : pandas.DataFrame
+ The DataFrame to validate.
+ name : str
+ A name for the DataFrame, used in error messages.
+
+ Raises
+ ------
+ ValueError
+ If any NaN values are found in the DataFrame.
+ """
nan_columns = df.columns[df.isna().any(axis=0)].to_list()
if nan_columns:
- error_message = (
- f"{name} has NaN values in {len(nan_columns)} columns. \n"
- )
+ error_message = f"{name} has NaN values in {len(nan_columns)} columns. \n"
error_message += f"Columns with NaN values: {', '.join(nan_columns)}\n"
if len(nan_columns) > 5:
error_message += "First 5 columns with NaN values: \n"
@@ -48,7 +97,27 @@ def validate_nonan(df: DataFrame, name: str) -> None:
def validate_positive(
df: DataFrame, columns: list[str], name: str, strict: bool = False
) -> None:
- """Validates that observation values in cols are non-negative or strictly positive"""
+ """
+ Validates that specified columns contain non-negative or strictly positive values.
+
+ Parameters
+ ----------
+ df : pandas.DataFrame
+ The DataFrame to validate.
+ columns : list of str
+ A list of column names to check for positive values.
+ name : str
+ A name for the DataFrame, used in error messages.
+ strict : bool, optional
+ If True, checks that values are strictly greater than zero.
+ If False, checks that values are greater than or equal to zero.
+ Default is False.
+
+ Raises
+ ------
+ ValueError
+ If any of the specified columns contain invalid (negative or zero) values.
+ """
op = "<=" if strict else "<"
negative = [col for col in columns if df.eval(f"{col} {op} 0").any()]
if negative:
@@ -59,11 +128,32 @@ def validate_positive(
def validate_interval(
df: DataFrame, lwr: str, upr: str, index: list[str], name: str
) -> None:
+ """
+ Validates that lower interval bounds are strictly less than upper bounds.
+
+ Parameters
+ ----------
+ df : pandas.DataFrame
+ The DataFrame containing interval data to validate.
+ lwr : str
+ The name of the column representing the lower bound of the interval.
+ upr : str
+ The name of the column representing the upper bound of the interval.
+ index : list of str
+ A list of column names to be used as the index for identifying intervals.
+ name : str
+ A name for the DataFrame, used in error messages.
+
+ Raises
+ ------
+ ValueError
+ If any lower bound is not strictly less than its corresponding upper bound.
+ """
invalid_index = pd.MultiIndex.from_frame(
df.query(f"{lwr} >= {upr}")[index]
).to_list()
if invalid_index:
- error_message = f"{name} has invalid interval with {len(invalid_index)} indices. \nLower age must be strictly less than upper age.\n"
+ error_message = f"{name} has invalid interval with {len(invalid_index)} indices. \nLower bound must be strictly less than upper bound.\n"
error_message += f"Index columns: ({', '.join(index)})\n"
if len(invalid_index) > 5:
error_message += "First 5 indices with invalid interval: \n"
@@ -75,17 +165,34 @@ def validate_interval(
def validate_noindexdiff(
df_ref: DataFrame, df: DataFrame, index: list[str], name: str
) -> None:
+ """
+ Validates that the indices of two DataFrames match.
+
+ Parameters
+ ----------
+ df_ref : pandas.DataFrame
+ The reference DataFrame containing the expected indices.
+ df : pandas.DataFrame
+ The DataFrame to validate against the reference.
+ index : list of str
+ A list of column names to be used as the index for comparison.
+ name : str
+ A name for the validation context, used in error messages.
+
+ Raises
+ ------
+ ValueError
+ If there are indices in the reference DataFrame that are missing in the DataFrame to validate.
+ """
index_ref = pd.MultiIndex.from_frame(df_ref[index])
- index = pd.MultiIndex.from_frame(df[index])
- missing_index = index_ref.difference(index).to_list()
+ index_to_check = pd.MultiIndex.from_frame(df[index])
+ missing_index = index_ref.difference(index_to_check).to_list()
if missing_index:
- error_message = (
- f"Missing {name} info for {len(missing_index)} indices \n"
- )
- error_message += f"Index columns: ({', '.join(index.names)})\n"
+ error_message = f"Missing {name} info for {len(missing_index)} indices \n"
+ error_message += f"Index columns: ({', '.join(index_ref.names)})\n"
if len(missing_index) > 5:
- error_message += "First 5: \n"
+ error_message += "First 5 missing indices: \n"
error_message += ", \n".join(str(idx) for idx in missing_index[:5])
error_message += "\n"
raise ValueError(error_message)
@@ -100,16 +207,36 @@ def validate_pat_coverage(
index: list[str],
name: str,
) -> None:
- """Validation checks for incomplete age pattern
- * pattern age intervals do not overlap or have gaps
- * smallest pattern interval doesn't cover the left end point of data
- * largest pattern interval doesn't cover the right end point of data
"""
- # sort dataframe
+ Validates that the pattern intervals cover the data intervals completely without gaps or overlaps.
+
+ Parameters
+ ----------
+ df : pandas.DataFrame
+ The DataFrame containing both data intervals and pattern intervals.
+ lwr : str
+ The name of the column representing the data's lower bound.
+ upr : str
+ The name of the column representing the data's upper bound.
+ pat_lwr : str
+ The name of the column representing the pattern's lower bound.
+ pat_upr : str
+ The name of the column representing the pattern's upper bound.
+ index : list of str
+ A list of column names to group by when validating intervals.
+ name : str
+ A name for the DataFrame or validation context, used in error messages.
+
+ Raises
+ ------
+ ValueError
+ If the pattern intervals have gaps or overlaps, or if they do not fully cover the data intervals.
+ """
+ # Sort dataframe
df = df.sort_values(index + [lwr, upr, pat_lwr, pat_upr], ignore_index=True)
df_group = df.groupby(index)
- # check overlap or gap in pattern
+ # Check overlap or gap in pattern
shifted_pat_upr = df_group[pat_upr].shift(1)
connect_index = shifted_pat_upr.notnull()
connected = np.allclose(
@@ -122,7 +249,7 @@ def validate_pat_coverage(
"bounds across categories."
)
- # check coverage of head and tail
+ # Check coverage of head and tail
head_covered = df_group.first().eval(f"{lwr} >= {pat_lwr}").all()
tail_covered = df_group.last().eval(f"{upr} <= {pat_upr}").all()
@@ -134,31 +261,61 @@ def validate_pat_coverage(
def validate_realnumber(df: DataFrame, columns: list[str], name: str) -> None:
"""
- Validates that observation values in columns are real numbers and non-zero.
+ Validates that specified columns contain real numbers, are non-zero, and are not NaN or Inf.
Parameters
----------
- df : DataFrame
+ df : pandas.DataFrame
The DataFrame containing the data to validate.
columns : list of str
A list of column names to validate within the DataFrame.
name : str
- A string representing the name of the data or dataset
- (used for constructing error messages).
+ A name for the DataFrame, used in error messages.
Raises
------
ValueError
- If any column contains values that are not real numbers or are zero.
+ If any column contains values that are not real numbers, are zero, or are NaN/Inf.
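+
+    Example
+    -------
+    A column of ``[1, 2.5, -3.5]`` passes; ``[1, 0, 2]`` or ``[1, np.inf, 2]``
+    raises.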
"""
- # Check for non-real or zero values in the specified columns
- invalid = [
- col
- for col in columns
- if not df[col]
- .apply(lambda x: isinstance(x, (int, float)) and x != 0)
- .all()
- ]
+ # Check for non-real, zero, NaN, or Inf values in the specified columns
+ invalid = []
+ for col in columns:
+ if (
+ not df[col]
+ .apply(
+ lambda x: isinstance(x, (int, float))
+ and x != 0
+ and pd.notna(x)
+ and np.isfinite(x)
+ )
+ .all()
+ ):
+ invalid.append(col)
if invalid:
raise ValueError(f"{name} has non-real or zero values in: {invalid}")
+
+
+def validate_set_uniqueness(df: DataFrame, column: str, name: str) -> None:
+ """
+ Validates that each list in the specified column contains unique elements.
+
+ Parameters
+ ----------
+ df : pandas.DataFrame
+ The DataFrame containing the data to validate.
+ column : str
+ The name of the column containing lists to validate.
+ name : str
+ A name for the DataFrame or validation context, used in error messages.
+
+ Raises
+ ------
+ ValueError
+ If any list in the specified column contains duplicate elements.
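+
+    Example
+    -------
+    A row value of ``[1, 2, 3]`` passes, while ``[1, 2, 2]`` raises.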
+ """
+ invalid_rows = df[df[column].apply(lambda x: len(x) != len(set(x)))]
+ if not invalid_rows.empty:
+ error_message = f"{name} has rows in column '{column}' where list elements are not unique.\n"
+ error_message += f"Indices of problematic rows: {invalid_rows.index.tolist()}\n"
+ raise ValueError(error_message)
diff --git a/tests/test_cat_splitter.py b/tests/test_cat_splitter.py
new file mode 100644
index 0000000..a2228f5
--- /dev/null
+++ b/tests/test_cat_splitter.py
@@ -0,0 +1,245 @@
+import pytest
+import pandas as pd
+from pydisagg.ihme.splitter import (
+ CatSplitter,
+ CatDataConfig,
+ CatPatternConfig,
+ CatPopulationConfig,
+)
+
+# Step 1: Setup Fixtures
+
+
+@pytest.fixture
+def cat_data_config():
+ return CatDataConfig(
+ index=[
+ "study_id",
+ "year_id",
+ "location_id",
+ "sub_category",
+ ], # Include 'sub_category' in index
+ cat_group="sub_category",
+ val="val",
+ val_sd="val_sd",
+ )
+
+
+@pytest.fixture
+def cat_pattern_config():
+ return CatPatternConfig(
+ by=["year_id", "location_id"],
+ cat="sub_category",
+ val="pattern_val",
+ val_sd="pattern_val_sd",
+ )
+
+
+@pytest.fixture
+def cat_population_config():
+ return CatPopulationConfig(
+ index=["year_id", "location_id", "sub_category"],
+ val="population",
+ )
+
+
+@pytest.fixture
+def valid_data():
+ return pd.DataFrame(
+ {
+ "study_id": [1, 2, 3],
+ "year_id": [2000, 2000, 2001],
+ "location_id": [10, 20, 10],
+ "sub_category": [
+ ["A1", "A2"], # List of sub_categories
+ ["B1", "B2"],
+ ["C1", "C2"],
+ ],
+ "val": [100, 200, 150],
+ "val_sd": [10, 20, 15],
+ }
+ )
+
+
+@pytest.fixture
+def valid_pattern():
+ return pd.DataFrame(
+ {
+ "year_id": [2000, 2000, 2000, 2000, 2001, 2001],
+ "location_id": [10, 10, 20, 20, 10, 10],
+ "sub_category": ["A1", "A2", "B1", "B2", "C1", "C2"],
+ "pattern_val": [0.6, 0.4, 0.7, 0.3, 0.55, 0.45],
+ "pattern_val_sd": [0.06, 0.04, 0.07, 0.03, 0.055, 0.045],
+ }
+ )
+
+
+@pytest.fixture
+def valid_population():
+ return pd.DataFrame(
+ {
+ "year_id": [2000, 2000, 2000, 2000, 2001, 2001],
+ "location_id": [10, 10, 20, 20, 10, 10],
+ "sub_category": ["A1", "A2", "B1", "B2", "C1", "C2"],
+ "population": [5000, 3000, 7000, 3000, 5500, 4500],
+ }
+ )
+
+
+@pytest.fixture
+def cat_splitter(cat_data_config, cat_pattern_config, cat_population_config):
+ return CatSplitter(
+ data=cat_data_config,
+ pattern=cat_pattern_config,
+ population=cat_population_config,
+ )
+
+
+# Step 2: Write Tests for parse_data
+
+
+def test_parse_data_duplicated_index(cat_splitter, valid_data):
+ """Test parse_data raises an error on duplicated index."""
+ duplicated_data = pd.concat([valid_data, valid_data])
+ with pytest.raises(ValueError, match="has duplicated index"):
+ cat_splitter.parse_data(duplicated_data, positive_strict=True)
+
+
+def test_parse_data_valid(cat_splitter, valid_data):
+ """Test that parse_data works correctly on valid data."""
+ parsed_data = cat_splitter.parse_data(valid_data, positive_strict=True)
+ assert not parsed_data.empty
+ assert "val" in parsed_data.columns
+ assert "val_sd" in parsed_data.columns
+
+
+# Step 3: Write Tests for parse_pattern
+
+
+def test_parse_pattern_valid(cat_splitter, valid_data, valid_pattern):
+ """Test that parse_pattern works correctly on valid data."""
+ parsed_data = cat_splitter.parse_data(valid_data, positive_strict=True)
+ parsed_pattern = cat_splitter.parse_pattern(
+ parsed_data, valid_pattern, model="rate"
+ )
+ assert not parsed_pattern.empty
+ # The pattern columns are renamed with prefix 'cat_pat_'
+ assert (
+ f"{cat_splitter.pattern.prefix}{cat_splitter.pattern.val}"
+ in parsed_pattern.columns
+ )
+ assert (
+ f"{cat_splitter.pattern.prefix}{cat_splitter.pattern.val_sd}"
+ in parsed_pattern.columns
+ )
+
+
+# Step 4: Write Tests for parse_population
+
+
+def test_parse_population_missing_columns(
+ cat_splitter, valid_data, valid_pattern, valid_population
+):
+ """Test parse_population raises an error when population columns are missing."""
+ invalid_population = valid_population.drop(columns=["population"])
+ parsed_data = cat_splitter.parse_data(valid_data, positive_strict=True)
+ parsed_pattern = cat_splitter.parse_pattern(
+ parsed_data, valid_pattern, model="rate"
+ )
+ with pytest.raises(KeyError, match="has missing columns"):
+ cat_splitter.parse_population(parsed_pattern, invalid_population)
+
+
+def test_parse_population_with_nan(
+ cat_splitter, valid_data, valid_pattern, valid_population
+):
+ """Test parse_population raises an error when there are NaN values."""
+ invalid_population = valid_population.copy()
+ invalid_population.loc[0, "population"] = None
+ parsed_data = cat_splitter.parse_data(valid_data, positive_strict=True)
+ parsed_pattern = cat_splitter.parse_pattern(
+ parsed_data, valid_pattern, model="rate"
+ )
+ with pytest.raises(ValueError, match="has NaN values"):
+ cat_splitter.parse_population(parsed_pattern, invalid_population)
+
+
+def test_parse_population_valid(
+ cat_splitter, valid_data, valid_pattern, valid_population
+):
+ """Test that parse_population works correctly on valid data."""
+ parsed_data = cat_splitter.parse_data(valid_data, positive_strict=True)
+ parsed_pattern = cat_splitter.parse_pattern(
+ parsed_data, valid_pattern, model="rate"
+ )
+ parsed_population = cat_splitter.parse_population(
+ parsed_pattern, valid_population
+ )
+ assert not parsed_population.empty
+ # The population column is renamed with prefix 'cat_pop_'
+ pop_col = f"{cat_splitter.population.prefix}{cat_splitter.population.val}"
+ assert pop_col in parsed_population.columns
+
+
+# Step 5: Write Tests for the split method
+
+
+def test_split_valid(cat_splitter, valid_data, valid_pattern, valid_population):
+ """Test that the split method works correctly on valid data."""
+ result = cat_splitter.split(
+ data=valid_data,
+ pattern=valid_pattern,
+ population=valid_population,
+ model="rate",
+ output_type="rate",
+ )
+ assert not result.empty
+ assert "split_result" in result.columns
+ assert "split_result_se" in result.columns
+
+
+def test_split_with_invalid_output_type(
+ cat_splitter, valid_data, valid_pattern, valid_population
+):
+ """Test that the split method raises an error with an invalid output_type."""
+ with pytest.raises(ValueError, match="Invalid output_type"):
+ cat_splitter.split(
+ data=valid_data,
+ pattern=valid_pattern,
+ population=valid_population,
+ model="rate",
+ output_type="invalid_output",
+ )
+
+
+def test_split_with_missing_population(cat_splitter, valid_data, valid_pattern):
+ """Test that the split method raises an error when population data is missing."""
+ with pytest.raises(
+ KeyError, match="Parsing Population has missing columns"
+ ):
+ cat_splitter.split(
+ data=valid_data,
+ pattern=valid_pattern,
+ population=pd.DataFrame(), # Empty population data
+ model="rate",
+ output_type="rate",
+ )
+
+
+def test_split_with_non_matching_categories(
+ cat_splitter, valid_data, valid_pattern, valid_population
+):
+ """Test that the split method raises an error when categories don't match."""
+ invalid_population = valid_population.copy()
+ invalid_population["sub_category"] = ["X1", "X2", "X1", "X2", "X1", "X2"]
+ with pytest.raises(
+ ValueError,
+ match="After merging with population, there were NaN values created",
+ ):
+ cat_splitter.split(
+ data=valid_data,
+ pattern=valid_pattern,
+ population=invalid_population,
+ model="rate",
+ output_type="rate",
+ )
diff --git a/tests/test_validator.py b/tests/test_validator.py
index 7c528f7..6a3464d 100644
--- a/tests/test_validator.py
+++ b/tests/test_validator.py
@@ -10,9 +10,12 @@
validate_nonan,
validate_pat_coverage,
validate_positive,
+ validate_set_uniqueness,
+ validate_realnumber,
)
+# Fixtures
@pytest.fixture
def data():
np.random.seed(123)
@@ -65,6 +68,7 @@ def population():
return population
+# Tests for validate_columns
def test_validate_columns_missing(population):
with pytest.raises(KeyError):
validate_columns(
@@ -74,6 +78,16 @@ def test_validate_columns_missing(population):
)
+def test_validate_columns_no_missing(population):
+ # All columns are present; should pass
+ validate_columns(
+ population,
+ ["sex_id", "location_id", "age_group_id", "year_id", "population"],
+ "population",
+ )
+
+
+# Tests for validate_index
def test_validate_index_missing(population):
with pytest.raises(ValueError):
validate_index(
@@ -83,11 +97,27 @@ def test_validate_index_missing(population):
)
+def test_validate_index_no_duplicates(population):
+ # Ensure DataFrame has no duplicate indices; should pass
+ validate_index(
+ population,
+ ["sex_id", "location_id", "age_group_id", "year_id"],
+ "population",
+ )
+
+
+# Tests for validate_nonan
def test_validate_nonan(population):
with pytest.raises(ValueError):
validate_nonan(population.assign(population=np.nan), "population")
+def test_validate_nonan_no_nan(population):
+ # No NaN values; should pass
+ validate_nonan(population, "population")
+
+
+# Tests for validate_positive
def test_validate_positive_strict(population):
with pytest.raises(ValueError):
validate_positive(
@@ -108,6 +138,12 @@ def test_validate_positive_not_strict(population):
)
+def test_validate_positive_no_error(population):
+ validate_positive(population, ["population"], "population", strict=True)
+ validate_positive(population, ["population"], "population", strict=False)
+
+
+# Tests for validate_interval
def test_validate_interval_lower_equal_upper(data):
with pytest.raises(ValueError):
validate_interval(
@@ -130,11 +166,7 @@ def test_validate_interval_positive(data):
validate_interval(data, "age_start", "age_end", ["uid"], "data")
-def test_validate_positive_no_error(population):
- validate_positive(population, ["population"], "population", strict=True)
- validate_positive(population, ["population"], "population", strict=False)
-
-
+# Tests for validate_noindexdiff
@pytest.fixture
def merged_data_pattern(data, pattern):
return pd.merge(
@@ -149,7 +181,7 @@ def test_validate_noindexdiff_merged_positive(merged_data_pattern, population):
# Positive test case: no index difference
validate_noindexdiff(
population,
- merged_data_pattern,
+ merged_data_pattern.dropna(subset=["sex_id", "location_id"]),
["sex_id", "location_id"],
"merged_data_pattern",
)
@@ -173,6 +205,7 @@ def test_validate_noindexdiff_merged_negative(data, pattern):
)
+# Tests for validate_pat_coverage
@pytest.mark.parametrize(
"bad_data_with_pattern",
[
@@ -221,3 +254,81 @@ def test_validate_pat_coverage_failure(bad_data_with_pattern):
["group_id"],
"pattern",
)
+
+
+# Tests for validate_realnumber
+def test_validate_realnumber_positive():
+ df = pd.DataFrame({"col1": [1, 2.5, -3.5, 4.2], "col2": [5.1, 6, 7, 8]})
+ # Should pass without exceptions
+ validate_realnumber(df, ["col1", "col2"], "df")
+
+
+def test_validate_realnumber_zero():
+ df = pd.DataFrame({"col1": [1, 2, 0, 4], "col2": [5, 6, 7, 8]})
+ with pytest.raises(
+ ValueError, match="df has non-real or zero values in: \\['col1'\\]"
+ ):
+ validate_realnumber(df, ["col1"], "df")
+
+
+def test_validate_realnumber_nan():
+ df = pd.DataFrame({"col1": [1, 2, 3, np.nan], "col2": [5, 6, 7, 8]})
+ with pytest.raises(
+ ValueError, match="df has non-real or zero values in: \\['col1'\\]"
+ ):
+ validate_realnumber(df, ["col1"], "df")
+
+
+def test_validate_realnumber_non_numeric():
+ df = pd.DataFrame({"col1": [1, 2, 3, "a"], "col2": [5, 6, 7, 8]})
+ with pytest.raises(
+ ValueError, match="df has non-real or zero values in: \\['col1'\\]"
+ ):
+ validate_realnumber(df, ["col1"], "df")
+
+
+def test_validate_realnumber_infinite():
+ df = pd.DataFrame({"col1": [1, 2, 3, np.inf], "col2": [5, 6, 7, 8]})
+ # np.inf is not a finite real number
+ with pytest.raises(
+ ValueError, match="df has non-real or zero values in: \\['col1'\\]"
+ ):
+ validate_realnumber(df, ["col1"], "df")
+
+
+# Tests for validate_set_uniqueness
+def test_validate_set_uniqueness_positive():
+ df = pd.DataFrame(
+ {"col1": [[1, 2, 3], ["a", "b", "c"], [True, False], [1.1, 2.2, 3.3]]}
+ )
+ # Should pass without exceptions
+ validate_set_uniqueness(df, "col1", "df")
+
+
+def test_validate_set_uniqueness_negative():
+ df = pd.DataFrame(
+ {"col1": [[1, 2, 2], ["a", "b", "a"], [True, False], [1.1, 2.2, 1.1]]}
+ )
+ with pytest.raises(
+ ValueError,
+ match="df has rows in column 'col1' where list elements are not unique.",
+ ):
+ validate_set_uniqueness(df, "col1", "df")
+
+
+def test_validate_set_uniqueness_empty_lists():
+ df = pd.DataFrame({"col1": [[], [], []]})
+ # Should pass; empty lists have no duplicates
+ validate_set_uniqueness(df, "col1", "df")
+
+
+def test_validate_set_uniqueness_single_element_lists():
+ df = pd.DataFrame({"col1": [[1], ["a"], [True]]})
+ # Should pass; single-element lists can't have duplicates
+ validate_set_uniqueness(df, "col1", "df")
+
+
+def test_validate_set_uniqueness_mixed_types_with_duplicates():
+ df = pd.DataFrame({"col1": [[1, "1", 1.0], [True, 1, 1.0], [2, 2, 2]]})
+ with pytest.raises(ValueError):
+ validate_set_uniqueness(df, "col1", "df")