diff --git a/README.md b/README.md
index 436db09..5cf3247 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,108 @@
 # Spark Model Helper
-Extracting useful information from trained Spark Model
+
+A helper that extracts useful information from a trained Spark model
+
+Are you tired of staring at model.toDebugString() for hours and getting no clue at all? Something like this:
+```text
+DecisionTreeClassificationModel (uid=dtc_e933455b) of depth 5 with 341 nodes
+  If (feature 518 <= 1.5)
+   If (feature 6 <= 2.5)
+    If (feature 30 <= 20)
+     If (feature 45 <= 7)
+      If (feature 24 <= 2.5)
+      ...
+   Else (feature 160 > 0.99)
+    If (feature 64 <= 3.5)
+     Predict: 0.0
+    Else (feature 64 > 3.5)
+     Predict: 1.0
+```
+Well, now you have this helper designed for humans, which matters.
+
+## Usage:
+
+### **DecisionTreeClassificationModel Analysis**
+***
+
+#### _1. Get Root Feature Index from trained model_
+
+In automated ML, you sometimes need to retrain a model because scalar evaluation metrics (precision and recall,
+for instance) are not within the desired range. Using the root feature index, we can change the DataFrame accordingly.
+
+```scala
+    val helper = new DecisionModelHelper(model)
+    println("Root Feature Index: " + helper.getRootFeature)
+```
+
+#### _2. Get the JSON string from DecisionTreeClassificationModel_
+
+```scala
+    val helper = new DecisionModelHelper(model)
+    println("toJson: " + helper.toJson + "\n")
+```
+
+Returns a beautified JSON string:
+
+```json
+    {
+      "featureIndex": 367,
+      "gain": 0.10617781879627633,
+      "impurity": 0.3144732024264778,
+      "threshold": 1.5,
+      "nodeType": "internal",
+      "splitType": "continuous",
+      "prediction": 0.0,
+      "leftChild": {
+      ....
+      "path": "F(367)|0.3144732024264778|0.0|1.5|"
+    }
+```
+
+#### _3. Return an Object of Model Node-Tree_
+
+```scala
+    val nodeObj = helper.getDecisionNode
+    println("getDecisionNode: " + nodeObj)
+```
+
+This is the object version of the JSON string.
+
+#### _4. Return Root to Leaf Path of Rules_
+
+```scala
+    val rules = helper.getRulesList(1, 0.2)
+    rules.foreach(rule => {
+      println("Rule: " + rule.mkString(", "))
+    })
+```
+The above code prints:
+```text
+    Rule: F(396)|0.3144732024264778|0.0|1.5|L, F(12)|0.49791192623975383|1.0|2.5|R, F(223)|0.2998340735773348|1.0|2500000.0|R, F(20)|0.19586076183802947|1.0|3.523971665E10|L, None|0.1902980108641974|1.0|None|E
+```
+The function returns a List[List[String]]; each String has the structure
+```text
+    Feature_Index | impurity | prediction | threshold | node_type
+```
+For example, if feature index **45** has impurity **3.5**, prediction **1**, threshold **1.5**, and the path goes **right** after this node, the string will be:
+```text
+    F(45)|3.5|1|1.5|R
+```
+Leaf nodes have **"E"** as node_type
+
+#### _5. Customise the Feature_Index_
+Feature indices are not designed for human reading, especially with a large number of columns.
+The helper also supports custom feature names
+```scala
+    val helper = new DecisionModelHelper(model)
+    helper.setFeatureName(
+      Map(0 -> "UserID", 1 -> "UserCity", 2 -> "Salary" ... )
+    )
+```
+After setFeatureName(Map[Int, String]) is called, the helper automatically changes F(1) into "UserCity"
+```text
+    F(1)|3.5|1|1.5|R
+```
+will be output as
+```text
+    UserCity|3.5|1|1.5|R
+```
\ No newline at end of file
diff --git a/build.sbt b/build.sbt
index 3b9ae02..6f1306a 100644
--- a/build.sbt
+++ b/build.sbt
@@ -1,13 +1,14 @@
 import sbt.Keys.libraryDependencies
 
-ThisBuild / version := "1.1.0-SNAPSHOT"
-
+ThisBuild / version := "1.1.0"
+ThisBuild / versionScheme := Some("pvp")
 ThisBuild / scalaVersion := "2.12.15"
 
+
 lazy val root = (project in file("."))
   .settings(
     name := "SparkModelHelper",
-    idePackagePrefix := Some("io.github.RaistlinTao"),
+    idePackagePrefix := Some("io.github.raistlintao"),
     // https://mvnrepository.com/artifact/net.liftweb/lift-json
     libraryDependencies += "net.liftweb" %% "lift-json" % "3.5.0",
     // https://mvnrepository.com/artifact/com.alibaba/fastjson
diff --git a/src/test/scala/Usage.scala b/src/test/scala/Usage.scala
index b6d4a65..42bd10d 100644
--- a/src/test/scala/Usage.scala
+++ b/src/test/scala/Usage.scala
@@ -13,14 +13,10 @@ object Usage {
     println("toDebugString:" + model.toDebugString + "\n")
     val helper = new DecisionModelHelper(model)
     helper.setFeatureName(
-      Map(0 -> "F0", 1 -> "F1", 2 -> "F0", 3 -> "F1", 4 -> "F0", 5 -> "F1", 6 -> "F0", 7 -> "F1", 8 -> "F0", 9 -> "F1",
-        10 -> "DSF", 11 -> "ZCX", 12 -> "CSD", 13 -> "GF", 14 -> "F0", 15 -> "F1", 16 -> "F0", 17 -> "F1", 18 -> "F0", 19 -> "F1",
-        20 -> "SC", 21 -> "CZX", 22 -> "KJL", 23 -> "GF", 24 -> "FG", 25 -> "F1", 26 -> "FFG0", 27 -> "GF1", 28 -> "FM0", 29 -> "FV1",
-        30 -> "CZX", 31 -> "XCZ", 32 -> "HGJ", 33 -> "BB", 34 -> "F0", 35 -> "F1", 36 -> "F0", 37 -> "FA1", 38 -> "F0V", 39 -> "FV1",
-        40 -> "XZ", 41 -> "XC", 42 -> "HGJ", 43 -> "BB", 44 -> "F0", 45 -> "F1", 46 -> "F0D", 47 -> "F1A", 48 -> "FV0", 49 -> "FV1",
-        50 -> "CZ", 51 -> "CX", 52 -> "HN", 53 -> "GG", 54 -> "F0", 55 -> "F1", 56 -> "F0", 57 -> "F1"
-      )
+      Map(0 -> "F0", 1 -> "F1", 2 -> "F0", 3 -> "F1", 4 -> "F0")
     )
+    val root_feature_index = helper.getRootFeature
+    println("Root Feature Index: " + root_feature_index)
     val jsonStr = helper.toJson
     println("toJson: " + jsonStr + "\n")
     val nodeObj = helper.getDecisionNode