Skip to content

Commit

Permalink
debugged and working c++ implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
lenarddome committed Jul 8, 2022
1 parent 35f6e7a commit c75d02f
Showing 1 changed file with 35 additions and 63 deletions.
98 changes: 35 additions & 63 deletions src/pspGlobal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ cube OrdinalCompare(cube discovered, cube predicted) {
for (int y = 0; y < discovered.n_slices; y++) {
mat base = discovered.slice(y);
umat result = (base == current);
index(x, y) = all(any(result == 0));
uvec comparisons = result(trimatu_ind(size(result), 1));
index(x, y) = any(comparisons == 0);
}
if (all(index.row(x) == 1)) {
cube update = join_slices(drawer, current);
Expand All @@ -63,7 +64,8 @@ mat LastEvaluatedParameters(cube discovered, cube predicted, mat jumping, mat ce
for (uword y = 0; y < discovered.n_slices; y++) {
mat base = discovered.slice(y);
umat result = (base == current);
index(x, y) = all(any(result == 0));
uvec comparisons = result(trimatu_ind(size(result), 1));
index(x, y) = any(comparisons == 0);
}
if (all(index.row(x) == 1)) {
// if there is a new region, appeng params to centers
Expand All @@ -87,7 +89,8 @@ rowvec CountOrdinal(cube updated_ordinal, cube predicted, rowvec counts) {
for (int y = 0; y < predicted.n_slices; y++) {
mat base = predicted.slice(y);
umat result = (base == current);
index(x, y) = all(any(result == 0));
uvec comparisons = result(trimatu_ind(size(result), 1));
index(x, y) = any(comparisons == 0);
}
if (!all(index.row(x) == 1)) {
new_counts[x] += 1;
Expand All @@ -106,37 +109,16 @@ uvec MatchJumpDists(cube updated_ordinal, cube predicted) {
for (uword y = 0; y < predicted.n_slices; y++) {
mat base = predicted.slice(y);
umat result = (base == current);
index(x, y) = all(any(result == 0));
uvec comparisons = result(trimatu_ind(size(result), 1));
index(x, y) = any(comparisons == 0);
}
}
matches = find(sum(index, 1) < predicted.n_slices);
return(matches);
}

// writes rows to csv file
void WriteFile(int iteration, mat evaluation, int dimension, uvec matches,
std::string path_to_file) {
// open file stream connection
std::ofstream outFile(path_to_file.c_str());
int rows = evaluation.n_rows;
int columns = dimension + 1;
for (uword i = 0; i < rows; i++) {
outFile << i + ",";
for (uword k = 0; k < columns; k++) {
outFile << evaluation[i, k] + ",";
}
outFile << matches[i] + ",\n";
}
// close file connection
outFile.close();
}

// [[Rcpp::export]]
List pspGlobal(std::string fn, List control, std::string filename,
std::string path = ".", bool quiet = false) {
// call the ordinal function used for evaluation parameters
Environment env = Environment::global_env();
Function model = env[fn];
List pspGlobal(Function model, List control, bool quiet = false) {

// setup environment
bool parameter_filled = false;
Expand All @@ -154,61 +136,50 @@ List pspGlobal(std::string fn, List control, std::string filename,
}

if (population == datum::inf && max_iteration == datum::inf) {
stop("A resonable threshold must be set by either adjusting iteration or population.")
stop("A resonable threshold must be set by either adjusting iteration or population.");
}

rowvec radius = as<colvec>(control["radius"]);
rowvec init = as<colvec>(control["init"]);
rowvec radius = as<rowvec>(control["radius"]);
rowvec init = as<rowvec>(control["init"]);

colvec lower = as<colvec>(control["lower"]);
colvec upper = as<colvec>(control["upper"]);
int dimension = init.n_elem;
int dimensions = init.n_elem;
// do some basic error checks
if (dimension != lower.n_elem || dimension != upper.n_elem {
if (dimensions != lower.n_elem || dimensions != upper.n_elem) {
stop("init, lower and upper must have the same length.");
}
mat jumping_distribution(init); // the parameter sets to be evaluated
mat last_eval(init); // the last evaluated parameter set for each ordinal pattern
rowvec counts; // keeps track of the population of ordinal regions
cube ordinal; // stores all evaluations of fn on jumping_distribution
cube storage; // stores all unique ordinal patterns

CharacterVector names = as<CharacterVector>(control["param_names"]);

List out;

// setup file and create headers
std::ofstream outFile(path + filename);
outFile << "iteration,";
for (uword i = 0; i < dimensions; i++) {
outFile << names[i] + ",";
}
outFile << names + ",pattern,\n";
// close file connection
outFile.close();

// evaluate first parameter set
mat output = model(init);
int stimuli = output.n_rows;
NumericMatrix teatime = model(init);
const mat& evaluate = as<mat>(teatime);
int stimuli = evaluate.n_rows;
mat last_eval = init; // last evaluated parameters
// add output to storage
storage = join_slices(storage, output)
delete[] output;
storage = join_slices(storage, evaluate);
// delete[] output;

// run parameter space partitioning until parameter is filled
while (parameter_filled) {
while (!parameter_filled) {
// update iteration
iteration += 1;

// generate new jumping distributions from ordinal patterns with counts < population
jumping_distribution = HyperPoints(last_eval.n_rows, dimensions, radius);
mat jumping_distribution = HyperPoints(last_eval.n_rows, dimensions, radius);
jumping_distribution = ClampParameters(jumping_distribution, lower, upper);
jumping_distribution += last_evaluated;
jumping_distribution.shed_rows(find(counts > pop));
jumping_distribution += last_eval;
jumping_distribution.shed_rows(find(counts > population));

cube ordinal(stimuli, stimuli, jumping_distribution.n_rows);
// evaluate jumping distributions
for (uword i = 0; i < jumping_distribution.n_rows; i++) {
mat evaluate = model(jumping_distribution.row(i));
NumericMatrix teatime = model(jumping_distribution.row(i));
const mat& evaluate = as<mat>(teatime);
ordinal.slice(i) = evaluate;
}
// compare ordinal patterns to stored ones and update list
Expand All @@ -217,23 +188,24 @@ List pspGlobal(std::string fn, List control, std::string filename,
// update counts of ordinal patterns
counts = CountOrdinal(storage, ordinal, counts);
// write data to disk
outFile << "\n";
// outFile << "\n";
rowvec vector_counts = vectorise(counts, 1);
irowvec print_counts = conv_to< irowvec >::from(vector_counts);

// print information about iteration
if (!quiet) {
Rcout << "Iteration:" << iteration << std::endl;
std::cout << "[" << iteration << "]: " << print_counts << std::endl;
}

// check if parameter_filled threshold is reached
if (iteration == threshold || all(counts > population)) {
parameter_filled = TRUE
if (iteration == max_iteration || all(counts > population)) {
parameter_filled = TRUE;
}
}


// compile output including ordinal patterns and their frequencies
out = Rcpp::List::create(
Rcpp::Named("ordinal_counts") = counts,
Rcpp::Named("ordinal_patterns") = storage);

// compile output including ordinal patterns and their frequencies
return(out)
return(out);
}

0 comments on commit c75d02f

Please sign in to comment.