forked from alteryx/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Manish Amde <manish9ue@gmail.com>
- Loading branch information
1 parent
12738c1
commit cd53eae
Showing
12 changed files
with
318 additions
and
0 deletions.
There are no files selected for viewing
21 changes: 21 additions & 0 deletions
21
mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.spark.mllib.classification | ||
|
||
/**
 * Empty placeholder class in the `classification` package.
 *
 * It takes no constructor parameters and declares no members.
 */
class ClassificationTree
21 changes: 21 additions & 0 deletions
21
mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionTree.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.spark.mllib.regression | ||
|
||
/**
 * Empty placeholder class in the `regression` package.
 *
 * It takes no constructor parameters and declares no members.
 */
class RegressionTree
54 changes: 54 additions & 0 deletions
54
mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.tree | ||
|
||
import org.apache.spark.rdd.RDD | ||
import org.apache.spark.mllib.regression.LabeledPoint | ||
import org.apache.spark.mllib.tree.model.{Split, Bin, DecisionTreeModel} | ||
|
||
|
||
/**
 * Decision tree trainer configured by a [[Strategy]].
 *
 * Training is still a stub: `train` caches the input, asks the companion
 * object for split/bin placeholders and returns a freshly constructed,
 * empty [[DecisionTreeModel]].
 *
 * @param strategy configuration forwarded to `DecisionTree.find_splits_bins`
 */
class DecisionTree(val strategy : Strategy) {

  /**
   * Trains a model on `input`.
   *
   * @param input labeled training points; marked for caching as a side effect
   * @return a new [[DecisionTreeModel]] (currently always empty)
   */
  def train(input : RDD[LabeledPoint]) : DecisionTreeModel = {

    //Cache input RDD for speedup during multiple passes
    input.cache()

    //TODO: Find all splits and bins using quantiles including support for categorical features, single-pass
    // The result is not consumed yet; it is kept for the level-wise training step below.
    val (splits, bins) = DecisionTree.find_splits_bins(input, strategy)

    //TODO: Level-wise training of tree and obtain Decision Tree model

    // The last expression is the method's result; no explicit `return` is needed.
    new DecisionTreeModel()
  }

}
|
||
object DecisionTree {

  /**
   * Allocates the split and bin tables for `input`.
   *
   * This is still a stub: it sizes two `numFeatures x numSplits` arrays but does
   * not populate them, so every element is `null`.
   *
   * @param input labeled points; must contain at least one element, whose
   *              `features.length` determines the number of features
   * @param strategy supplies `numSplits`
   * @return a pair of (splits, bins) arrays, each `numFeatures x numSplits`
   * @throws IllegalArgumentException if `input` is empty
   */
  def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {
    val numSplits = strategy.numSplits
    //TODO: Justify this calculation
    // Widen to Long before multiplying so that a large numSplits cannot overflow Int.
    val requiredSamples : Long = numSplits.toLong * numSplits
    val count : Long = input.count()
    // Fail with a clear message instead of an index error on take(1)(0) below.
    require(count > 0, "DecisionTree.find_splits_bins requires a non-empty input RDD")
    // Not consumed yet; kept for the pending quantile-sampling implementation.
    val numSamples : Long = math.min(requiredSamples, count)
    val numFeatures = input.take(1)(0).features.length
    (Array.ofDim[Split](numFeatures, numSplits), Array.ofDim[Bin](numFeatures, numSplits))
  }

}
15 changes: 15 additions & 0 deletions
15
mllib/src/main/scala/org/apache/spark/mllib/tree/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
This package contains the default implementation of the decision tree algorithm. | ||
|
||
The decision tree algorithm supports: | ||
+ impurity calculation using entropy or Gini for classification and variance for regression | ||
+ node model pruning | ||
+ printing to dot files | ||
+ unit tests | ||
|
||
# Performance testing | ||
|
||
# Future Extensions | ||
|
||
+ Random forests | ||
+ Boosting | ||
+ Extremely randomized trees |
28 changes: 28 additions & 0 deletions
28
mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.spark.mllib.tree | ||
|
||
import org.apache.spark.mllib.tree.impurity.Impurity | ||
|
||
/**
 * Configuration values for decision tree training.
 *
 * @param kind string tag; the accepted values are not defined here (TODO confirm)
 * @param impurity impurity measure to use
 * @param maxDepth presumably the maximum tree depth; verify against the trainer
 * @param numSplits number of splits; `DecisionTree.find_splits_bins` uses it to
 *                  size the split and bin arrays
 * @param quantileCalculationStrategy string tag, defaults to "sampleAndSort"
 */
class Strategy(
    val kind: String,
    val impurity: Impurity,
    val maxDepth: Int,
    val numSplits: Int,
    val quantileCalculationStrategy: String = "sampleAndSort")
34 changes: 34 additions & 0 deletions
34
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.spark.mllib.tree.impurity | ||
|
||
/** Entropy-based implementation of [[Impurity]] for two counts. */
object Entropy extends Impurity {

  /** Base-2 logarithm of `x`, computed as ln(x) / ln(2). */
  def log2(x: Double): Double = scala.math.log(x) / scala.math.log(2)

  /**
   * Entropy, in bits, of the two-way distribution given by the counts.
   *
   * @param c0 first count
   * @param c1 second count
   * @return 0 when either count is zero, otherwise
   *         -(p0 * log2(p0)) - (p1 * log2(p1)) with p_i = c_i / (c0 + c1)
   */
  def calculate(c0: Double, c1: Double): Double = {
    if (c0 != 0 && c1 != 0) {
      val sum = c0 + c1
      val p0 = c0 / sum
      val p1 = c1 / sum
      -(p0 * log2(p0)) - (p1 * log2(p1))
    } else {
      // Zero is returned directly so that log2(0) is never evaluated.
      0
    }
  }

}
28 changes: 28 additions & 0 deletions
28
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.spark.mllib.tree.impurity | ||
|
||
/** Gini-index implementation of [[Impurity]] for two counts. */
object Gini extends Impurity {

  /**
   * Gini impurity of the two-way distribution given by the counts.
   *
   * @param c0 first count
   * @param c1 second count
   * @return 0 when either count is zero, otherwise 1 - p0^2 - p1^2 with
   *         p_i = c_i / (c0 + c1)
   */
  def calculate(c0 : Double, c1 : Double): Double = {
    if (c0 == 0 || c1 == 0) {
      // Same guard as Entropy: without it, two zero counts give 0/0 = NaN.
      0
    } else {
      val total = c0 + c1
      val f0 = c0 / total
      val f1 = c1 / total
      1 - f0*f0 - f1*f1
    }
  }

}
23 changes: 23 additions & 0 deletions
23
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.spark.mllib.tree.impurity | ||
|
||
/** Contract for impurity measures computed from a pair of counts. */
trait Impurity {

  /**
   * Computes the impurity value for the two given counts.
   *
   * @param c0 first count
   * @param c1 second count
   * @return the impurity as defined by the implementing object
   */
  def calculate(c0: Double, c1: Double): Double

}
23 changes: 23 additions & 0 deletions
23
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.spark.mllib.tree.impurity | ||
|
||
import javax.naming.OperationNotSupportedException | ||
|
||
/**
 * [[Impurity]] implementation whose two-count `calculate` is not supported.
 */
object Variance extends Impurity {

  /**
   * Always fails; no value is ever returned.
   *
   * @throws javax.naming.OperationNotSupportedException on every call
   */
  def calculate(c0: Double, c1: Double): Double = {
    throw new OperationNotSupportedException("Variance.calculate")
  }

}
21 changes: 21 additions & 0 deletions
21
mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.spark.mllib.tree.model | ||
|
||
/**
 * Immutable record pairing two [[Split]] values with a kind tag.
 *
 * @param kind string tag; the accepted values are not defined here (TODO confirm)
 * @param lowSplit presumably the split bounding the bin from below; verify against usage
 * @param highSplit presumably the split bounding the bin from above; verify against usage
 */
case class Bin(kind: String, lowSplit: Split, highSplit: Split)
21 changes: 21 additions & 0 deletions
21
mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.spark.mllib.tree.model | ||
|
||
/**
 * Empty placeholder for the model returned by decision tree training.
 *
 * It takes no constructor parameters and declares no members.
 */
class DecisionTreeModel
29 changes: 29 additions & 0 deletions
29
mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.spark.mllib.tree.model | ||
|
||
/**
 * Immutable description of a split on one feature.
 *
 * Case-class parameters are already public vals, so no explicit `val` is needed.
 *
 * @param feature integer identifying the feature (presumably its index; TODO confirm)
 * @param threshold threshold value associated with the split
 * @param kind string tag; the accepted values are not defined here (TODO confirm)
 */
case class Split(feature: Int, threshold: Double, kind: String)
|
||
/**
 * [[Split]] whose feature is `Int.MinValue` and whose threshold is
 * `Double.MinValue` (in Scala, the most negative finite Double).
 *
 * @param kind kind tag passed through to [[Split]]
 */
class dummyLowSplit(kind: String)
  extends Split(Int.MinValue, Double.MinValue, kind)
|
||
/**
 * [[Split]] whose feature is `Int.MaxValue` and whose threshold is
 * `Double.MaxValue`.
 *
 * @param kind kind tag passed through to [[Split]]
 */
class dummyHighSplit(kind: String)
  extends Split(Int.MaxValue, Double.MaxValue, kind)
|