forked from wverbeke/PFNEvaluation
-
Notifications
You must be signed in to change notification settings - Fork 1
/
PFNReader.cc
127 lines (100 loc) · 4.25 KB
/
PFNReader.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include "PFNReader.h"
// include C++ standard library headers
#include <stdexcept>
#include <iostream>
// Number of PFNReader objects currently counted as alive: incremented by both constructors
// and decremented by the destructor.
unsigned PFNReader::instanceCounter{ 0 };

// Python module shared by all readers. Starts out as a default-constructed object and is
// assigned by loadPythonModule().
boost::python::object PFNReader::pythonModule{};
// Build a reader for the model stored in model_file_name.
//   highlevelInputShape : expected number of high-level input values
//   pfnInputShape       : expected (outer, inner) sizes of the PFN input; {0, 0} selects a BDT model
// When no instance is counted yet, the python interpreter is set up and the "kerasPredict"
// module is imported before the model itself is loaded.
PFNReader::PFNReader( const std::string& model_file_name, const unsigned highlevelInputShape, const std::pair<unsigned, unsigned>& pfnInputShape ):
    pfnShape( pfnInputShape ), highlevelShape( highlevelInputShape )
{
    const bool noInstanceYet = ( instanceCounter == 0 );
    if( noInstanceYet ){
        initializePythonAPI();
        loadPythonModule( "kerasPredict" );
    }
    ++instanceCounter;

    // a pfnInputShape of {0, 0} means the model is a BDT, which has no PFN input variables
    const bool hasPfnInputs = !( pfnInputShape.first == 0 && pfnInputShape.second == 0 );
    isPFN = hasPfnInputs;
    loadKerasModel( hasPfnInputs ? "kerasModel" : "xgboostModel", model_file_name );
}
// Copy constructor: the copy shares the python model and predict routine of rhs, takes over
// its expected input shapes and its model kind, and is counted as an additional live instance.
PFNReader::PFNReader( const PFNReader& rhs) :
    pfnShape( rhs.pfnShape ),
    highlevelShape( rhs.highlevelShape ),
    kerasModel( rhs.kerasModel ),
    predictRoutine( rhs.predictRoutine )
{
    // isPFN decides which prediction path predict() takes, so it has to be copied as well.
    // It is assigned in the body (as the other constructor does) so that this does not
    // depend on the member declaration order in the header.
    isPFN = rhs.isPFN;
    ++instanceCounter;
}
// Destructor: takes this object out of the live-instance count. Nothing else is released here.
PFNReader::~PFNReader(){
    instanceCounter -= 1;
}
void PFNReader::initializePythonAPI(){
Py_Initialize();
PyObject* sysPath = PySys_GetObject( "path" );
std::string file_path = __FILE__;
std::string dir_path = file_path.substr(0, file_path.rfind("/"));
PyList_Insert( sysPath, 0, PyUnicode_FromString( dir_path.c_str() ));
PyList_Insert( sysPath, 0, PyUnicode_FromString("./"));
}
void PFNReader::loadPythonModule( const std::string& module_name ){
try{
pythonModule = boost::python::import( module_name.c_str() );
} catch( ... ){
PyErr_Print();
}
}
void PFNReader::loadKerasModel( const char * model_type, const std::string& model_name ){
try{
kerasModel = pythonModule.attr( model_type )( model_name );
predictRoutine = kerasModel.attr("predict");
} catch( ... ){
PyErr_Print();
}
}
// Convert a std::vector into a python list by appending its elements one by one, in order.
template < typename T > boost::python::list vectorToPythonList( const std::vector< T >& vec ){
    boost::python::list converted;
    for( auto it = vec.cbegin(); it != vec.cend(); ++it ){
        converted.append( *it );
    }
    return converted;
}
// Check that vec has the shape stored in pfnShape: pfnShape.first rows, each holding
// pfnShape.second values. Every row is checked, so an input whose rows have different
// lengths is rejected. Returns true when the shape matches.
bool PFNReader::checkPfnShape( const std::vector< std::vector< double > >& vec) const{
    if( vec.size() != pfnShape.first ) return false;
    for( const auto& row : vec ){
        if( row.size() != pfnShape.second ) return false;
    }
    return true;
}
// Returns true when vec holds exactly highlevelShape values.
bool PFNReader::checkHighlevelShape( const std::vector< double >& vec) const{
    const bool sizeMatches = ( highlevelShape == vec.size() );
    return sizeMatches;
}
// Evaluate the loaded model. A PFN model (isPFN set) gets both inputs; for a BDT model only
// highlevelInput is used and pfnInput is ignored.
double PFNReader::predict( const std::vector< double >& highlevelInput, const std::vector< std::vector< double > >& pfnInput ) const{
    return isPFN ? predictPFN( highlevelInput, pfnInput ) : predictBDT( highlevelInput );
}
double PFNReader::predictPFN( const std::vector< double >& highlevelInput, const std::vector< std::vector< double > >& pfnInput ) const{
if( !checkPfnShape( pfnInput ) ){
throw std::invalid_argument( "PFN input vector has wrong shape. Shape is supposed to be (" + std::to_string( pfnShape.first ) + ", " + std::to_string( pfnShape.second ) + ")." );
}
if( !checkHighlevelShape( highlevelInput ) ){
throw std::invalid_argument( "High level input vector has wrong shape. Shape is supposed to be (" + std::to_string( highlevelShape ) + ")." );
}
std::vector< boost::python::list > pfn_list_vector;
for( const auto& vec : pfnInput ){
pfn_list_vector.push_back( vectorToPythonList( vec ) );
}
boost::python::list pfn_list = vectorToPythonList( pfn_list_vector );
boost::python::list highlevel_list = vectorToPythonList( highlevelInput );
boost::python::object pythonPrediction = predictRoutine( highlevel_list, pfn_list );
return boost::python::extract< double >( pythonPrediction );
}
// Evaluate a BDT model: validate the high-level input, convert it to a python list, call the
// python predict routine with that list and return its result as a double.
// Throws std::invalid_argument when the input does not hold highlevelShape values.
double PFNReader::predictBDT( const std::vector< double >& highlevelInput ) const{
    const bool shapeOk = checkHighlevelShape( highlevelInput );
    if( !shapeOk ){
        const std::string message = "BDT input vector has wrong shape. Shape is supposed to be (" + std::to_string( highlevelShape ) + ").";
        throw std::invalid_argument( message );
    }

    boost::python::list highlevel_values = vectorToPythonList( highlevelInput );
    boost::python::object pythonResult = predictRoutine( highlevel_values );
    const double prediction = boost::python::extract< double >( pythonResult );
    return prediction;
}