diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala index e469a7ac9bc85..34b447584e521 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -48,12 +48,13 @@ private[mllib] class BinaryClassificationPMMLModelExport( val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1") var interceptNO = threshold if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) { - if (threshold <= 0) - interceptNO = -1000 - else if (threshold >= 1) - interceptNO = 1000 - else - interceptNO = -math.log(1/threshold -1) + if (threshold <= 0) { + interceptNO = Double.MinValue + } else if (threshold >= 1) { + interceptNO = Double.MaxValue + } else { + interceptNO = -math.log(1 / threshold - 1) + } } val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0") val regressionModel = new RegressionModel() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala index 965e2785c3acc..c16e83d6a067d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -47,13 +47,14 @@ private[mllib] object PMMLModelExportFactory { svm, "linear SVM", RegressionNormalizationMethodType.NONE, svm.getThreshold.getOrElse(0.0)) case logistic: LogisticRegressionModel => - if (logistic.numClasses == 2) + if (logistic.numClasses == 2) { new BinaryClassificationPMMLModelExport( logistic, "logistic regression", RegressionNormalizationMethodType.LOGIT, logistic.getThreshold.getOrElse(0.5)) - else + } else { throw new IllegalArgumentException( "PMML Export not supported for Multinomial Logistic Regression") + } case _ => throw new IllegalArgumentException( "PMML Export not supported for model: " + model.getClass.getName)