Skip to content

Commit

Permalink
Add factory methods for summary, proposal, kernel functions
Browse files Browse the repository at this point in the history
  • Loading branch information
apulsipher committed Oct 22, 2024
1 parent f02d864 commit 7cf7281
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 16 deletions.
12 changes: 12 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ set_observed_data_cpp <- function(lfmcmc, observed_data_) {
.Call(`_epiworldR_set_observed_data_cpp`, lfmcmc, observed_data_)
}

create_LFMCMCProposalFun_cpp <- function(fun) {
.Call(`_epiworldR_create_LFMCMCProposalFun_cpp`, fun)
}

set_proposal_fun_cpp <- function(lfmcmc, fun) {
.Call(`_epiworldR_set_proposal_fun_cpp`, lfmcmc, fun)
}
Expand All @@ -244,10 +248,18 @@ set_simulation_fun_cpp <- function(lfmcmc, fun) {
.Call(`_epiworldR_set_simulation_fun_cpp`, lfmcmc, fun)
}

create_LFMCMCSummaryFun_cpp <- function(fun) {
.Call(`_epiworldR_create_LFMCMCSummaryFun_cpp`, fun)
}

set_summary_fun_cpp <- function(lfmcmc, fun) {
.Call(`_epiworldR_set_summary_fun_cpp`, lfmcmc, fun)
}

create_LFMCMCKernelFun_cpp <- function(fun) {
.Call(`_epiworldR_create_LFMCMCKernelFun_cpp`, fun)
}

set_kernel_fun_cpp <- function(lfmcmc, fun) {
.Call(`_epiworldR_set_kernel_fun_cpp`, lfmcmc, fun)
}
Expand Down
24 changes: 24 additions & 0 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,13 @@ extern "C" SEXP _epiworldR_set_observed_data_cpp(SEXP lfmcmc, SEXP observed_data
END_CPP11
}
// lfmcmc.cpp
SEXP create_LFMCMCProposalFun_cpp(cpp11::function fun);
extern "C" SEXP _epiworldR_create_LFMCMCProposalFun_cpp(SEXP fun) {
BEGIN_CPP11
return cpp11::as_sexp(create_LFMCMCProposalFun_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::function>>(fun)));
END_CPP11
}
// lfmcmc.cpp
SEXP set_proposal_fun_cpp(SEXP lfmcmc, SEXP fun);
extern "C" SEXP _epiworldR_set_proposal_fun_cpp(SEXP lfmcmc, SEXP fun) {
BEGIN_CPP11
Expand All @@ -433,13 +440,27 @@ extern "C" SEXP _epiworldR_set_simulation_fun_cpp(SEXP lfmcmc, SEXP fun) {
END_CPP11
}
// lfmcmc.cpp
SEXP create_LFMCMCSummaryFun_cpp(cpp11::function fun);
extern "C" SEXP _epiworldR_create_LFMCMCSummaryFun_cpp(SEXP fun) {
BEGIN_CPP11
return cpp11::as_sexp(create_LFMCMCSummaryFun_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::function>>(fun)));
END_CPP11
}
// lfmcmc.cpp
SEXP set_summary_fun_cpp(SEXP lfmcmc, SEXP fun);
extern "C" SEXP _epiworldR_set_summary_fun_cpp(SEXP lfmcmc, SEXP fun) {
BEGIN_CPP11
return cpp11::as_sexp(set_summary_fun_cpp(cpp11::as_cpp<cpp11::decay_t<SEXP>>(lfmcmc), cpp11::as_cpp<cpp11::decay_t<SEXP>>(fun)));
END_CPP11
}
// lfmcmc.cpp
SEXP create_LFMCMCKernelFun_cpp(cpp11::function fun);
extern "C" SEXP _epiworldR_create_LFMCMCKernelFun_cpp(SEXP fun) {
BEGIN_CPP11
return cpp11::as_sexp(create_LFMCMCKernelFun_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::function>>(fun)));
END_CPP11
}
// lfmcmc.cpp
SEXP set_kernel_fun_cpp(SEXP lfmcmc, SEXP fun);
extern "C" SEXP _epiworldR_set_kernel_fun_cpp(SEXP lfmcmc, SEXP fun) {
BEGIN_CPP11
Expand Down Expand Up @@ -1021,7 +1042,10 @@ static const R_CallMethodDef CallEntries[] = {
{"_epiworldR_agents_smallworld_cpp", (DL_FUNC) &_epiworldR_agents_smallworld_cpp, 5},
{"_epiworldR_change_state_cpp", (DL_FUNC) &_epiworldR_change_state_cpp, 4},
{"_epiworldR_clone_model_cpp", (DL_FUNC) &_epiworldR_clone_model_cpp, 1},
{"_epiworldR_create_LFMCMCKernelFun_cpp", (DL_FUNC) &_epiworldR_create_LFMCMCKernelFun_cpp, 1},
{"_epiworldR_create_LFMCMCProposalFun_cpp", (DL_FUNC) &_epiworldR_create_LFMCMCProposalFun_cpp, 1},
{"_epiworldR_create_LFMCMCSimFun_cpp", (DL_FUNC) &_epiworldR_create_LFMCMCSimFun_cpp, 1},
{"_epiworldR_create_LFMCMCSummaryFun_cpp", (DL_FUNC) &_epiworldR_create_LFMCMCSummaryFun_cpp, 1},
{"_epiworldR_distribute_entity_randomly_cpp", (DL_FUNC) &_epiworldR_distribute_entity_randomly_cpp, 3},
{"_epiworldR_distribute_entity_to_set_cpp", (DL_FUNC) &_epiworldR_distribute_entity_to_set_cpp, 1},
{"_epiworldR_distribute_tool_randomly_cpp", (DL_FUNC) &_epiworldR_distribute_tool_randomly_cpp, 2},
Expand Down
64 changes: 58 additions & 6 deletions src/lfmcmc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,27 +46,44 @@ SEXP set_observed_data_cpp(
return lfmcmc;
}

// LFMCMC Proposal Function
[[cpp11::register]]
SEXP create_LFMCMCProposalFun_cpp(
cpp11::function fun
) {

LFMCMCProposalFun<TData_default> fun_call = [fun](std::vector< epiworld_double >& params_now,const std::vector< epiworld_double >& params_prev, LFMCMC<TData_default>* model) -> void {
WrapLFMCMC(lfmcmc_ptr)(model);
fun(params_now, params_prev, lfmcmc_ptr);
return;
};

return cpp11::external_pointer<LFMCMCProposalFun<TData_default>>(
new LFMCMCProposalFun<TData_default>(fun_call)
);
}

[[cpp11::register]]
SEXP set_proposal_fun_cpp(
SEXP lfmcmc,
SEXP fun
) {
cpp11::external_pointer<LFMCMCProposalFun<TData_default>> fun_ptr(fun);
cpp11::external_pointer<LFMCMCProposalFun<TData_default>> fun_ptr = create_LFMCMCProposalFun_cpp(fun);
WrapLFMCMC(lfmcmc_ptr)(lfmcmc);
lfmcmc_ptr->set_proposal_fun(*fun_ptr);
return lfmcmc;
}

// LFMCMC Simulation Function
[[cpp11::register]]
SEXP create_LFMCMCSimFun_cpp(
cpp11::function fun
) {

LFMCMCSimFun<TData_default> fun_call = [fun](const std::vector<epiworld_double>& params, LFMCMC<TData_default>* model) -> TData_default {
WrapLFMCMC(lfmcmc_ptr)(model);
SEXP res = fun(params, lfmcmc_ptr);
cpp11::external_pointer<TData_default> res_vec(res);
return *res_vec;
cpp11::external_pointer<TData_default> res(fun(params, lfmcmc_ptr));
return *res;
};

return cpp11::external_pointer<LFMCMCSimFun<TData_default>>(
Expand All @@ -85,23 +102,58 @@ SEXP set_simulation_fun_cpp(
return lfmcmc;
}

// LFMCMC Summary Function
[[cpp11::register]]
SEXP create_LFMCMCSummaryFun_cpp(
cpp11::function fun
) {

LFMCMCSummaryFun<TData_default> fun_call = [fun](std::vector< epiworld_double >& res, const TData_default& dat, LFMCMC<TData_default>* model) -> void {
WrapLFMCMC(lfmcmc_ptr)(model);
fun(res, dat, lfmcmc_ptr);
return;
};

return cpp11::external_pointer<LFMCMCSummaryFun<TData_default>>(
new LFMCMCSummaryFun<TData_default>(fun_call)
);
}

[[cpp11::register]]
SEXP set_summary_fun_cpp(
SEXP lfmcmc,
SEXP fun
) {
cpp11::external_pointer<LFMCMCSummaryFun<TData_default>> fun_ptr(fun);
cpp11::external_pointer<LFMCMCSummaryFun<TData_default>> fun_ptr = create_LFMCMCSummaryFun_cpp(fun);
WrapLFMCMC(lfmcmc_ptr)(lfmcmc);
lfmcmc_ptr->set_summary_fun(*fun_ptr);
return lfmcmc;
}

// LFMCMC Kernel Function
// TODO: clean up these really long lines
[[cpp11::register]]
SEXP create_LFMCMCKernelFun_cpp(
cpp11::function fun
) {

LFMCMCKernelFun<TData_default> fun_call = [fun](const std::vector< epiworld_double >& stats_now, const std::vector< epiworld_double >& stats_obs, epiworld_double epsilon, LFMCMC<TData_default>* model) -> epiworld_double {
WrapLFMCMC(lfmcmc_ptr)(model);
cpp11::external_pointer<epiworld_double> res(fun(stats_now, stats_obs, epsilon, lfmcmc_ptr));
return *res;
};

return cpp11::external_pointer<LFMCMCKernelFun<TData_default>>(
new LFMCMCKernelFun<TData_default>(fun_call)
);
}

[[cpp11::register]]
SEXP set_kernel_fun_cpp(
SEXP lfmcmc,
SEXP fun
) {
cpp11::external_pointer<LFMCMCKernelFun<TData_default>> fun_ptr(fun);
cpp11::external_pointer<LFMCMCKernelFun<TData_default>> fun_ptr = create_LFMCMCKernelFun_cpp(fun);
WrapLFMCMC(lfmcmc_ptr)(lfmcmc);
lfmcmc_ptr->set_kernel_fun(*fun_ptr);
return lfmcmc;
Expand Down
19 changes: 9 additions & 10 deletions vignettes/likelihood-free-mcmc.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,21 @@ simfun <- function(params, m) {
}
# TODO: Define Summary Function
sumfun <- function(res, dat, m) {
# if (res.size() == 0)
# if (length(res) == 0)
# res.resize(data.size())
# for (i in dat.size())
# res[i] = static_cast<float>(dat[i])
# return
return
}
# TODO: Define Proposal Function
propfun <- function(scale, lb, ub) {
propfun <- function(params_now, params_prev, m) {
return
}
# TODO: Define Kernel Function
kernfun <- function() {
return(1.0)
}
# Set initial parameters
Expand All @@ -101,11 +101,10 @@ par0 <- c(.5, .5)
```{r lfmcmc-run}
# TODO: make these work
lfmcmc_model <- LFMCMC() |>
set_simulation_fun(simfun)
# set_simulation_fun(lfmcmc_model, simfun)
# set_summary_fun(sumfun) |>
# set_proposal_fun(propfun) |>
# set_kernel_fun(kernfun) |>
set_simulation_fun(simfun) |>
set_summary_fun(sumfun) |>
set_proposal_fun(propfun) |>
set_kernel_fun(kernfun)
# set_observed_data(obs_dat) |>
# run_lfmcmc(par0, 2000, 1)
Expand Down

0 comments on commit 7cf7281

Please sign in to comment.