diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExport.scala
index f986c0cb95348..c0fea6ad95d9b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExport.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExport.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.mllib.export
 
 import java.io.OutputStream
+import java.io.FileOutputStream
+import java.io.File
 
 trait ModelExport {
 
@@ -25,5 +27,21 @@ trait ModelExport {
    * Write the exported model to the output stream specified
    */
   def save(outputStream: OutputStream): Unit
+
+  /**
+   * Write the exported model to the local file specified.
+   * The underlying file stream is closed before returning, even if save throws.
+   *
+   * @param path location of the local file that receives the exported model
+   */
+  def saveLocalFile(path: String): Unit = {
+    val fileStream = new FileOutputStream(new File(path))
+    try {
+      save(fileStream)
+    } finally {
+      // Release the file handle whether or not the export succeeded
+      fileStream.close()
+    }
+  }
 
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala
new file mode 100644
index 0000000000000..fc627fcb75584
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.export
+
+import org.apache.spark.mllib.clustering.KMeansModel
+import org.apache.spark.mllib.linalg.Vectors
+import org.scalatest.FunSuite
+import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport
+
+/**
+ * Tests for ModelExportFactory.createModelExport: a KMeansModel must yield a
+ * KMeansPMMLModelExport, and an unsupported object must be rejected.
+ */
+class ModelExportFactorySuite extends FunSuite {
+
+  test("ModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") {
+    // Three 3-dimensional cluster centers used to build the model under test
+    val centers = Array(
+      Vectors.dense(1.0, 2.0, 6.0),
+      Vectors.dense(1.0, 3.0, 0.0),
+      Vectors.dense(1.0, 4.0, 6.0)
+    )
+    val model = new KMeansModel(centers)
+
+    val exporter = ModelExportFactory.createModelExport(model, ModelExportType.PMML)
+
+    assert(exporter.isInstanceOf[KMeansPMMLModelExport])
+  }
+
+  test("ModelExportFactory generate IllegalArgumentException when passing an unsupported model") {
+    // A plain Object stands in for a model type the factory is expected to reject
+    val unsupported = new Object
+
+    intercept[IllegalArgumentException] {
+      ModelExportFactory.createModelExport(unsupported, ModelExportType.PMML)
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExportSuite.scala
new file mode 100644
index 0000000000000..02339b0e20e28
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExportSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.export.pmml
+
+import org.scalatest.FunSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.export.ModelExportFactory
+import org.apache.spark.mllib.clustering.KMeansModel
+import org.apache.spark.mllib.export.ModelExportType
+
+/** Builds the PMML exporter for a small KMeansModel and runs getPmml and save on it. */
+class KMeansPMMLModelExportSuite extends FunSuite {
+
+  test("KMeansPMMLModelExport generate PMML format") {
+    // Three 3-dimensional cluster centers used to build the model under test
+    val centers = Array(
+      Vectors.dense(1.0, 2.0, 6.0),
+      Vectors.dense(1.0, 3.0, 0.0),
+      Vectors.dense(1.0, 4.0, 6.0)
+    )
+    val model = new KMeansModel(centers)
+
+    val exporter = ModelExportFactory.createModelExport(model, ModelExportType.PMML)
+
+    assert(exporter.isInstanceOf[PMMLModelExport])
+    val pmmlExporter = exporter.asInstanceOf[PMMLModelExport]
+
+    // TODO: assert on the generated PMML, e.g. compare its fields to expected strings
+    pmmlExporter.getPmml()
+    // TODO: load the generated XML with a document builder and validate its nodes
+    // by looking them up; the output is currently only written to stdout
+    pmmlExporter.save(System.out)
+    // TODO: cover saveLocalFile too (find out how to unit test file creation)
+  }
+
+}