GBT varImportance (#138)
* GBT varImportance, refactor enums

* version control in GBT cls example

* version condition for printing of result in GBT example
Alexander-Makaryev authored Oct 18, 2019
1 parent 9e86424 commit de475aa
Showing 3 changed files with 46 additions and 12 deletions.
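
For quick reference — this is an illustrative sketch only, not part of the commit — the user-facing surface added here can be exercised as follows on DAAL 2020.0 or newer. The parameter and result names (varImportance, resultsToEvaluate, probabilities, the variableImportanceBy* tables) are taken from the example diff below; the toy data and the maxIterations value are placeholders.

import numpy as np
import daal4py as d4p

# placeholder training/test data: 3 features, 5 classes
train_data = np.random.rand(100, 3).astype(np.float32)
train_labels = np.random.randint(0, 5, (100, 1)).astype(np.float32)
test_data = np.random.rand(20, 3).astype(np.float32)

# request all five variable-importance modes as a '|'-separated bit-mask string
train_algo = d4p.gbt_classification_training(
    nClasses=5,
    maxIterations=50,
    varImportance='weight|totalCover|cover|totalGain|gain')
train_result = train_algo.compute(train_data, train_labels)
print(train_result.variableImportanceByGain)

# request class probabilities in addition to class labels
predict_algo = d4p.gbt_classification_prediction(
    nClasses=5,
    resultsToEvaluate="computeClassLabels|computeClassProbabilities")
predict_result = predict_algo.compute(test_data, train_result.model)
print(predict_result.probabilities[0:5])
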
35 changes: 30 additions & 5 deletions examples/gradient_boosted_classification_batch.py
@@ -40,18 +40,33 @@ def main(readcsv=read_csv, method='defaultDense'):
testfile = "./data/batch/df_classification_test.csv"

# Configure a training object (5 classes)
train_algo = d4p.gbt_classification_training(nClasses=nClasses,
maxIterations=maxIterations,
minObservationsInLeafNode=minObservationsInLeafNode,
featuresPerNode=nFeatures)
# previous version has different interface
from daal4py import __daal_link_version__ as dv
daal_version = tuple(map(int, (dv[0:4], dv[4:8])))
if daal_version < (2020,0):
train_algo = d4p.gbt_classification_training(nClasses=nClasses,
maxIterations=maxIterations,
minObservationsInLeafNode=minObservationsInLeafNode,
featuresPerNode=nFeatures)
else:
train_algo = d4p.gbt_classification_training(nClasses=nClasses,
maxIterations=maxIterations,
minObservationsInLeafNode=minObservationsInLeafNode,
featuresPerNode=nFeatures,
varImportance='weight|totalCover|cover|totalGain|gain')

# Read data. Let's use 3 features per observation
data = readcsv(infile, range(3), t=np.float32)
labels = readcsv(infile, range(3,4), t=np.float32)
train_result = train_algo.compute(data, labels)

# Now let's do some prediction
predict_algo = d4p.gbt_classification_prediction(5)
# previous version has different interface
if daal_version < (2020,0):
predict_algo = d4p.gbt_classification_prediction(nClasses=nClasses)
else:
predict_algo = d4p.gbt_classification_prediction(nClasses=nClasses,
resultsToEvaluate="computeClassLabels|computeClassProbabilities")
# read test data (with same #features)
pdata = readcsv(testfile, range(3), t=np.float32)
# now predict using the model from the training above
@@ -68,4 +83,14 @@ def main(readcsv=read_csv, method='defaultDense'):
(train_result, predict_result, plabels) = main()
print("\nGradient boosted trees prediction results (first 10 rows):\n", predict_result.prediction[0:10])
print("\nGround truth (first 10 rows):\n", plabels[0:10])
# these results are available only in new version
from daal4py import __daal_link_version__ as dv
daal_version = tuple(map(int, (dv[0:4], dv[4:8])))
if daal_version >= (2020,0):
print("\nGradient boosted trees prediction probabilities (first 10 rows):\n", predict_result.probabilities[0:10])
print("\nvariableImportanceByWeight:\n", train_result.variableImportanceByWeight)
print("\nvariableImportanceByTotalCover:\n", train_result.variableImportanceByTotalCover)
print("\nvariableImportanceByCover:\n", train_result.variableImportanceByCover)
print("\nvariableImportanceByTotalGain:\n", train_result.variableImportanceByTotalGain)
print("\nvariableImportanceByGain:\n", train_result.variableImportanceByGain)
print('All looks good!')
13 changes: 6 additions & 7 deletions generator/gen_daal4py.py
@@ -29,7 +29,7 @@
from collections import defaultdict, OrderedDict
from jinja2 import Template
from .parse import parse_header, parse_version
from .wrappers import required, ignore, defaults, has_dist, ifaces, no_warn, no_constructor, add_setup, enum_maps, wrap_algo
from .wrappers import required, ignore, defaults, has_dist, ifaces, no_warn, no_constructor, add_setup, enum_maps, enum_params, wrap_algo
from .wrapper_gen import wrapper_gen, typemap_wrapper_template
from .format import mk_var

@@ -265,11 +265,13 @@ def get_all_attrs(self, ns, cls, attr, ons=None):


###############################################################################
def to_lltype(self, t):
def to_lltype(self, p, t):
"""
return low level (C++ type). Usually the same as input.
Only very specific cases need a conversion.
"""
if p in enum_params:
return enum_params[p]
if t in ['DAAL_UINT64']:
return 'ResultToComputeId'
return t
@@ -285,9 +287,6 @@ def to_hltype(self, ns, t):
'?' means we do not know what 't' is
For classes, we also add lookups in namespaces that DAAL C++ API finds through "using".
"""
if t in ['DAAL_UINT64']:
### FIXME
t = 'ResultToComputeId'
tns, tname = splitns(t)
if t in ['double', 'float', 'int', 'size_t',]:
return (t, 'stdtype', '')
@@ -681,10 +680,10 @@ def prepare_hlwrapper(self, ns, mode, func, no_dist, no_stream):
for p in all_params:
pns, tmp = splitns(p)
if not tmp.startswith('_') and not ignored(pns, tmp):
hlt = self.to_hltype(pns, all_params[p][0])
llt = self.to_lltype(p, all_params[p][0])
hlt = self.to_hltype(pns, llt)
if hlt and hlt[1] in ['stdtype', 'enum', 'class']:
(hlt, hlt_type, hlt_ns) = hlt
llt = self.to_lltype(all_params[p][0])
pval = None
if hlt_type == 'enum':
thetype = hlt_ns + '::' + llt.rsplit('::', 1)[-1]
10 changes: 10 additions & 0 deletions generator/wrappers.py
@@ -182,6 +182,16 @@ def wrap_algo(algo, ver):
'algorithms::pca::ResultToComputeId' : 'result_dataForTransform',
}

# Enum values are used to build bit-mask Parameters.
# The Parameter itself is declared as DAAL_UINT64, so its possible values cannot be
# determined from the type alone; this dict maps a Parameter to the Enum that holds its values.
# A parameter not listed here is assumed to be a 'ResultToComputeId'.
# Parameter -> Enum of values
enum_params = {
'algorithms::gbt::classification::training::varImportance': 'algorithms::gbt::training::VariableImportanceModes',
'algorithms::gbt::regression::training::varImportance': 'algorithms::gbt::training::VariableImportanceModes',
}

# The distributed algorithm configuration parameters
# Note that all have defaults and so are optional.
# In particular note that the name of a single input argument defaults to data.
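
To make the connection between the two generator files concrete, here is a small standalone sketch (not part of the commit) of the lookup that the updated to_lltype in gen_daal4py.py performs against this dict: a DAAL_UINT64 parameter listed in enum_params is typed with its dedicated enum, while any other DAAL_UINT64 parameter keeps the ResultToComputeId fallback. The second lookup below uses a hypothetical parameter name purely for illustration.

enum_params = {
    'algorithms::gbt::classification::training::varImportance':
        'algorithms::gbt::training::VariableImportanceModes',
}

def to_lltype(p, t):
    # parameters with a dedicated value enum are resolved through enum_params
    if p in enum_params:
        return enum_params[p]
    # all remaining DAAL_UINT64 parameters are treated as ResultToComputeId
    if t in ['DAAL_UINT64']:
        return 'ResultToComputeId'
    return t

print(to_lltype('algorithms::gbt::classification::training::varImportance', 'DAAL_UINT64'))
# -> algorithms::gbt::training::VariableImportanceModes
print(to_lltype('some::other::parameter', 'DAAL_UINT64'))  # hypothetical name
# -> ResultToComputeId
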
