Skip to content

Commit

Permalink
Release Version 1.1, Updating Document
Browse files Browse the repository at this point in the history
  • Loading branch information
RaistlinTAO committed Dec 29, 2021
1 parent 8fe4283 commit 35481ac
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 11 deletions.
108 changes: 107 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,108 @@
# Spark Model Helper
Extracting useful information from trained Spark Model

A helper that extracting useful information from trained Spark Model

Have you tired staring at the model.toDebugString() for hours and getting no clue at all? Something like this:
```text
DecisionTreeClassificationModel (uid=dtc_e933455b) of depth 5 with 341 nodes
If (feature 518 <= 1.5)
If (feature 6 <= 2.5)
If (feature 30 <= 20)
If (feature 45 <= 7)
If (feature 24 <= 2.5)
...
Else (feature 160 > 0.99)
If (feature 64 <= 3.5)
Predict: 0.0
Else (feature 64 > 3.5)
Predict: 1.0
```
Well now you have this helper designed for HUMAN, which matters.

## Usage:

### **DecisionTreeClassificationModel Analysis**
***

#### _1. Get Root Feature Index from trained model_

In automated ML, sometimes you need to retrain model due to the scalar metric as evaluation result (precision and recall
for instance) are not within a desired range. By using rootFeatureIndex we can change the dataframe accordingly.

```scala
val helper = new DecisionModelHelper(model)
println("Root Feature Index: " + helper.getRootFeature)
```

#### _2. Get the JSON string from DecisionTreeClassificationModel_

```scala
val helper = new DecisionModelHelper(model)
println("toJson: " + helper.toJson + "\n")
```

Return beatified JSON string:

```json
{
"featureIndex": 367,
"gain": 0.10617781879627633,
"impurity": 0.3144732024264778,
"threshold": 1.5,
"nodeType": "internal",
"splitType": "continuous",
"prediction": 0.0,
"leftChild": {
....
"path": "F(367)|0.3144732024264778|0.0|1.5|"
}
```

### _3. Return an Object of Model Node-Tree_

```scala
val nodeObj = helper.getDecisionNode
println("getDecisionNode: " + nodeObj)
```

Object version of JSON String

### _4. Return Root to Leaf Path of Rules_

```scala
val rules = helper.getRulesList(1, 0.2)
rules.foreach(rule => {
println("Rule: " + rule.mkString(", "))
})
```
The above code prints:
```text
Rule: F(396)|0.3144732024264778|0.0|1.5|L, F(12)|0.49791192623975383|1.0|2.5|R, F(223)|0.2998340735773348|1.0|2500000.0|R, F(20)|0.19586076183802947|1.0|3.523971665E10|L, None|0.1902980108641974|1.0|None|E
```
The function returns List[List[String]], the structure of each String is
```text
Feature_Index | impurity | prediction | threshold | node_type
```
For example, Feature Index **45** has impurity **3.5**, prediction **1**, threshold **1.5** and the path goes **right** after this node, the string will be:
```text
F(45)|3.5|1|1.5|R
```
The Leaf nodes will have **"E"** as node_type

### _5. Customise the Feature_Index_
Feature Index is not designed for human reading, especially with large amount of columns.
The helper also supports customisation of Features
```scala
val helper = new DecisionModelHelper(model)
helper.setFeatureName(
Map(0 -> "UserID", 1 -> "UserCity", 2 -> "Salary" ... )
)
```
The helper automatically change the F(1) into "UserCity" upon called setFeatureName(Map[Int, String])
```text
F(1)|3.5|1|1.5|R
```
will output as
```text
UserCity|3.5|1|1.5|R
```
7 changes: 4 additions & 3 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import sbt.Keys.libraryDependencies

ThisBuild / version := "1.1.0-SNAPSHOT"

ThisBuild / version := "1.1.0"
ThisBuild / versionScheme := Some("pvp")
ThisBuild / scalaVersion := "2.12.15"


lazy val root = (project in file("."))
.settings(
name := "SparkModelHelper",
idePackagePrefix := Some("io.github.RaistlinTao"),
idePackagePrefix := Some("io.github.raistlintao"),
// https://mvnrepository.com/artifact/net.liftweb/lift-json
libraryDependencies += "net.liftweb" %% "lift-json" % "3.5.0",
// https://mvnrepository.com/artifact/com.alibaba/fastjson
Expand Down
10 changes: 3 additions & 7 deletions src/test/scala/Usage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,10 @@ object Usage {
println("toDebugString:" + model.toDebugString + "\n")
val helper = new DecisionModelHelper(model)
helper.setFeatureName(
Map(0 -> "F0", 1 -> "F1", 2 -> "F0", 3 -> "F1", 4 -> "F0", 5 -> "F1", 6 -> "F0", 7 -> "F1", 8 -> "F0", 9 -> "F1",
10 -> "DSF", 11 -> "ZCX", 12 -> "CSD", 13 -> "GF", 14 -> "F0", 15 -> "F1", 16 -> "F0", 17 -> "F1", 18 -> "F0", 19 -> "F1",
20 -> "SC", 21 -> "CZX", 22 -> "KJL", 23 -> "GF", 24 -> "FG", 25 -> "F1", 26 -> "FFG0", 27 -> "GF1", 28 -> "FM0", 29 -> "FV1",
30 -> "CZX", 31 -> "XCZ", 32 -> "HGJ", 33 -> "BB", 34 -> "F0", 35 -> "F1", 36 -> "F0", 37 -> "FA1", 38 -> "F0V", 39 -> "FV1",
40 -> "XZ", 41 -> "XC", 42 -> "HGJ", 43 -> "BB", 44 -> "F0", 45 -> "F1", 46 -> "F0D", 47 -> "F1A", 48 -> "FV0", 49 -> "FV1",
50 -> "CZ", 51 -> "CX", 52 -> "HN", 53 -> "GG", 54 -> "F0", 55 -> "F1", 56 -> "F0", 57 -> "F1"
)
Map(0 -> "F0", 1 -> "F1", 2 -> "F0", 3 -> "F1", 4 -> "F0")
)
val root_feature_index = helper.getRootFeature
println("Root Feature Index: " + root_feature_index)
val jsonStr = helper.toJson
println("toJson: " + jsonStr + "\n")
val nodeObj = helper.getDecisionNode
Expand Down

0 comments on commit 35481ac

Please sign in to comment.