diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExport.scala index 274366208bd36..f986c0cb95348 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExport.scala @@ -21,6 +21,9 @@ import java.io.OutputStream trait ModelExport { + /** + * Write the exported model to the output stream specified + */ def save(outputStream: OutputStream): Unit } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala index afce4e305aaac..7e2e76f53988c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala @@ -19,14 +19,22 @@ package org.apache.spark.mllib.export import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport +import org.apache.spark.mllib.export.ModelExportType._ object ModelExportFactory { - //TODO: introduce model export typed - - def createModelExport(model: Any): ModelExport = model match { - case kmeans: KMeansModel => new KMeansPMMLModelExport - case _ => throw new IllegalArgumentException("Export not supported for model " + model.getClass) + /** + * Factory object to help creating the necessary ModelExport implementation + * taking as input the ModelExportType (for example PMML) and the machine learning model (for example KMeansModel). + */ + def createModelExport(model: Any, exportType: ModelExportType): ModelExport = { + return exportType match{ + case PMML => model match{ + case kmeans: KMeansModel => new KMeansPMMLModelExport(kmeans) + case _ => throw new IllegalArgumentException("Export not supported for model: " + model.getClass) + } + case _ => throw new IllegalArgumentException("Export type not supported:" + exportType) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportType.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportType.scala index 5abb7d6bb4e71..1e940a6aa5e50 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportType.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportType.scala @@ -17,6 +17,10 @@ package org.apache.spark.mllib.export +/** + * Defines export types. + * - PMML exports the machine learning models in an XML-based file format called Predictive Model Markup Language developed by the Data Mining Group (www.dmg.org). + */ object ModelExportType extends Enumeration{ type ModelExportType = Value diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala index f53443e3e646d..99ab256adfd0b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala @@ -17,11 +17,19 @@ package org.apache.spark.mllib.export.pmml -class KMeansPMMLModelExport extends PMMLModelExport{ +import org.apache.spark.mllib.clustering.KMeansModel - populateKMeansPMML(); +/** + * PMML Model Export for KMeansModel class + */ +class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{ + + /** + * Export the input KMeansModel model to PMML format + */ + populateKMeansPMML(model); - def populateKMeansPMML(): Unit = { + private def populateKMeansPMML(model : KMeansModel): Unit = { //TODO: set here header description pmml.setVersion("testing... kmeans..."); //TODO: generate the model... diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala index 42203e6b9291a..6d8e8ff0797f6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala @@ -26,10 +26,17 @@ import scala.beans.BeanProperty trait PMMLModelExport extends ModelExport{ + /** + * Holder of the exported model in PMML format + */ @BeanProperty var pmml: PMML = new PMML(); //TODO: set here header app copyright and timestamp + /** + * Write the exported model (in PMML XML) to the output stream specified + */ + @Override def save(outputStream: OutputStream): Unit = { JAXBUtil.marshalPMML(pmml, new StreamResult(outputStream)); }