diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala index 0c374a46fb562..bd8c8f96a6e55 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -44,7 +44,11 @@ private[mllib] object PMMLModelExportFactory { new GeneralizedLinearPMMLModelExport(svm, "linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise") case logistic: LogisticRegressionModel => - new LogisticRegressionPMMLModelExport(logistic, "logistic regression") + if(logistic.numClasses == 2) + new LogisticRegressionPMMLModelExport(logistic, "logistic regression") + else + throw new IllegalArgumentException( + "PMML Export not supported for Multinomial Logistic Regression") case _ => throw new IllegalArgumentException( "PMML Export not supported for model: " + model.getClass.getName) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index a94854e4c0f20..b87e96e7032f3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -73,6 +73,18 @@ class PMMLModelExportFactorySuite extends FunSuite { assert(logisticRegressionModelExport.isInstanceOf[LogisticRegressionPMMLModelExport]) } + + test("PMMLModelExportFactory throw IllegalArgumentException " + + "when passing a Multinomial Logistic Regression") { + /** 3 classes, 2 features */ + val multiclassLogisticRegressionModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, + numFeatures = 2, numClasses = 3) + + intercept[IllegalArgumentException] { + PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel) + } + } test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") { val invalidModel = new Object