From d10babb3c52a10bac0aacace97cc06cba3c2a501 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Sat, 2 May 2015 02:17:47 -0700
Subject: [PATCH] addressed comments v0.2

---
 python/pyspark/sql/dataframe.py                        |  8 +++++---
 .../org/apache/spark/sql/DataFrameStatFunctions.scala  |  2 +-
 .../org/apache/spark/sql/DataFrameStatSuite.scala      | 10 ++++------
 3 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 5af12d2250ffd..d10fe98133993 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -875,9 +875,9 @@ def fillna(self, value, subset=None):
 
         return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
 
-    def corr(self, col1, col2, method="pearson"):
+    def corr(self, col1, col2, method=None):
         """
-        Calculate the correlation of two columns of a DataFrame as a double value. Currently only
+        Calculates the correlation of two columns of a DataFrame as a double value. Currently only
         supports the Pearson Correlation Coefficient.
         :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases.
 
@@ -889,6 +889,8 @@ def corr(self, col1, col2, method="pearson"):
             raise ValueError("col1 should be a string.")
         if not isinstance(col2, str):
             raise ValueError("col2 should be a string.")
+        if not method:
+            method = "pearson"
         if not method == "pearson":
             raise ValueError("Currently only the calculation of the Pearson Correlation " +
                              "coefficient is supported.")
@@ -1378,7 +1380,7 @@ class DataFrameStatFunctions(object):
     def __init__(self, df):
         self.df = df
 
-    def corr(self, col1, col2, method="pearson"):
+    def corr(self, col1, col2, method=None):
         return self.df.corr(col1, col2, method)
 
     corr.__doc__ = DataFrame.corr.__doc__
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 56630b794ef3f..903532105284e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -37,7 +37,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * @return The Pearson Correlation Coefficient as a Double.
    */
   def corr(col1: String, col2: String, method: String): Double = {
-    assert(method == "pearson", "Currently only the calculation of the Pearson Correlation " +
+    require(method == "pearson", "Currently only the calculation of the Pearson Correlation " +
       "coefficient is supported.")
     StatFunctions.pearsonCorrelation(df, Seq(col1, col2))
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index ef80545112bb6..6a9f5f945d953 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -30,10 +30,10 @@ class DataFrameStatSuite extends FunSuite {
   def toLetter(i: Int): String = (i + 97).toChar.toString
 
   test("Frequent Items") {
-    val rows = Array.tabulate(1000) { i =>
+    val rows = Seq.tabulate(1000) { i =>
       if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
     }
-    val df = sqlCtx.sparkContext.parallelize(rows).toDF("numbers", "letters", "negDoubles")
+    val df = rows.toDF("numbers", "letters", "negDoubles")
 
     val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
     val items = results.collect().head
@@ -46,8 +46,7 @@
   }
 
   test("pearson correlation") {
-    val df = sqlCtx.sparkContext.parallelize(
-      Array.tabulate(10)(i => (i, 2 * i, i * -1.0))).toDF("a", "b", "c")
+    val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
     val corr1 = df.stat.corr("a", "b", "pearson")
     assert(math.abs(corr1 - 1.0) < 1e-6)
     val corr2 = df.stat.corr("a", "c", "pearson")
@@ -55,8 +54,7 @@
   }
 
   test("covariance") {
-    val rows = Array.tabulate(10)(i => (i, 2.0 * i, toLetter(i)))
-    val df = sqlCtx.sparkContext.parallelize(rows).toDF("singles", "doubles", "letters")
+    val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")
     val results = df.stat.cov("singles", "doubles")
     assert(math.abs(results - 55.0 / 3) < 1e-6)
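
For context, a minimal PySpark usage sketch of the corr API touched by this patch; the SQLContext handle (sqlContext) and the sample data are illustrative assumptions, not part of the change itself:

    # Hypothetical example: build a small DataFrame and call the patched corr API.
    # Assumes an existing SQLContext bound to `sqlContext`.
    df = sqlContext.createDataFrame(
        [(i, 2 * i, i * -1.0) for i in range(10)], ["a", "b", "c"])

    # After this patch, `method` defaults to None and is normalized to "pearson".
    df.corr("a", "b")        # ~1.0 (perfect positive correlation)
    df.stat.corr("a", "c")   # ~-1.0, via the DataFrameStatFunctions alias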