Skip to content

Commit

Permalink
add tests for draws
Browse files Browse the repository at this point in the history
  • Loading branch information
ADucellierIHME committed Oct 16, 2024
1 parent bf49870 commit b08f59f
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 1 deletion.
42 changes: 41 additions & 1 deletion src/raking/run_raking.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ def run_raking(
pass
# Check if matrix is definite positive
(sigma_yy, sigma_ss, sigma_ys) = check_covariance(sigma_yy, sigma_ss, sigma_ys)


# Compute the mean (if we have draws)
if cov_mat:
(df_obs, df_margins) = compute_mean(df_obs, df_margins, varnames, draws)

# Get the input variables for the raking
if dim == 1:
(y, s, q, l, h, A) = run_raking_1D(df_obs, df_margins, var_names, weights, lower, upper, rtol, atol)
Expand Down Expand Up @@ -452,3 +456,39 @@ def compute_covariance_3D(
sigma_ys = compute_covariance_obs_margins_3D(df_obs, df_margins_1, df_margins_2, df_margins_3, var_names, draws)
return (sigma_yy, sigma_ss, sigma_ys)

def compute_mean(
df_obs: pd.DataFrame,
df_margins: list,
varnames: list,
draws: str
) -> tuple[pd.DataFrame, list]:
"""Compute the means of the values over all the samples.
Parameters
----------
df_obs : pd.DataFrame
Observations data
df_margins : list of pd.DataFrame
list of data frames contatining the margins data
var_names : list of strings
Names of the variables over which we rake (e.g. cause, race, county)
draws: string
Name of the column that contains the samples.
Returns
-------
df_obs_mean : pd.DataFrame
Means of observations data
df_margins_mean : list of pd.DataFrame
list of data frames contatining the mans of the margins data
"""
columns = df_obs.columns.drop([draws, 'value']).to_list()
df_obs_mean = df_obs.groupby(columns).mean().reset_index().drop(columns=[draws])
df_margins_mean = []
for (df_margin, var_name) in zip(df_margins, var_names):
value_name = 'value_agg_over_' + var_name
columns = df_margin.columns.drop([draws, value_name]).to_list()
df_margin_mean = df_margin.groupby(columns).mean().reset_index().drop(columns=[draws])
df_margins_mean.append(df_margin_mean)
return (df_obs_mean, df_margins_mean)

37 changes: 37 additions & 0 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,40 @@ def test_run_raking_3D(example_3D):
cov_mat=False
)

def test_run_raking_1D_draws(example_1D_draws):
(df_obs, Dphi_y, Dphi_s, sigma) = run_raking(
dim=1,
df_obs=example_1D_draws.df_obs,
df_margins=[example_1D_draws.df_margin],
var_names=['var1'],
draws='draws',
cov_mat=True
)

def test_run_raking_2D_draws(example_2D_draws):
(df_obs, Dphi_y, Dphi_s, sigma) = run_raking(
dim=2,
df_obs=example_2D_draws.df_obs,
df_margins=[
example_2D_draws.df_margins_1,
example_2D_draws.df_margins_2
],
var_names=['var1', 'var2'],
draws='draws',
cov_mat=True
)

def test_run_raking_3D_draws(example_3D_draws):
(df_obs, Dphi_y, Dphi_s, sigma) = run_raking(
dim=3,
df_obs=example_3D_draws.df_obs,
df_margins=[
example_3D_draws.df_margins_1,
example_3D_draws.df_margins_2,
example_3D_draws.df_margins_3,
],
var_names=['var1', 'var2', 'var3'],
draws='draws',
cov_mat=True
)

0 comments on commit b08f59f

Please sign in to comment.