-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.cpp
61 lines (37 loc) · 1.34 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include <cstdlib>
#include <exception>
#include <iostream>

#include "DRandomForest.h"
// ---- Run configuration ------------------------------------------------------

// Second constructor argument of both DData objects built in main().
constexpr unsigned int MAX_SAMPLES{256};

// Forest/tree sizing and stopping criteria, forwarded (in this order) to the
// DRandomForest constructor in main().
constexpr unsigned int DTREE_COUNT{100};
constexpr unsigned int MAX_DEPTH{9};
constexpr unsigned int MIN_SPLIT{1};
constexpr unsigned int MIN_LEAF{1};
constexpr double IMPURITY_THRESHOLD{0.01};

// Referenced only from the commented-out out-of-bag retry loop in main().
constexpr double OUT_OF_BAG_ERROR_THRESHOLD{0.0025};

// Boolean switches, also forwarded to the DRandomForest constructor.
constexpr bool BOOTSTRAPPING_ALLOWED{true};
constexpr bool REGRESSION{false};
constexpr bool MULTITHREAD{true};
// Strategy callbacks passed as the last two DRandomForest constructor
// arguments in main(). ImpurityFunctor, FeatureFunctor, calculateShannonEntropy
// and squareRoot are presumably declared via DRandomForest.h — not visible here.
// NOTE(review): squareRoot presumably controls how many features are sampled
// per split (e.g. sqrt of feature count) — TODO confirm against the header.
const ImpurityFunctor IMPURITY_FUNCTION = calculateShannonEntropy;
const FeatureFunctor FEATURE_FUNCTION = squareRoot;
/// Entry point: loads the training and test CSV files, fits a random forest on
/// the training set, then classifies the test set.
///
/// Any exception thrown while loading, fitting or classifying is reported on
/// stderr instead of escaping main() (which would call std::terminate).
///
/// @return EXIT_SUCCESS on completion, EXIT_FAILURE if an exception was caught.
int main()
{
    try
    {
        // Object initialization. MAX_SAMPLES is DData's second constructor
        // argument (presumably a cap on rows read — TODO confirm).
        DData trainingData("fruit training set 2.csv", MAX_SAMPLES);
        DData testData("fruit test set.csv", MAX_SAMPLES);
        DRandomForest randomForest(
            DTREE_COUNT, MAX_DEPTH, MIN_SPLIT, MIN_LEAF,
            IMPURITY_THRESHOLD,
            BOOTSTRAPPING_ALLOWED, REGRESSION, MULTITHREAD,
            IMPURITY_FUNCTION,
            FEATURE_FUNCTION
        );

        // Object interface use.
        randomForest.fit(trainingData);

        // Disabled: refit until the out-of-bag error drops below the threshold.
        // NOTE(review): this loop has no iteration cap and never terminates if
        // the threshold is unreachable; bound the number of attempts before
        // re-enabling it.
        /*
        while (randomForest.getOutOfBagError() > OUT_OF_BAG_ERROR_THRESHOLD)
        {
            randomForest.reset();
            randomForest.fit(trainingData);
        }
        */

        randomForest.classifyBatch(testData);
    }
    catch (const std::exception& e)
    {
        // Report instead of letting the exception reach std::terminate.
        std::cerr << "error: " << e.what() << '\n';
        return EXIT_FAILURE;
    }
    catch (...)
    {
        // Non-std exception types from project code: still fail cleanly.
        std::cerr << "error: unknown exception\n";
        return EXIT_FAILURE;
    }
    return EXIT_SUCCESS;
}