From 9662c9ee31a392f38181d4eb88a6cab7a18927d2 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 26 Jan 2015 23:48:23 -0800
Subject: [PATCH] improve DataFrame Python API

---
 python/pyspark/sql.py                         | 334 +++++++++++++-----
 python/pyspark/tests.py                       |  22 +-
 .../org/apache/spark/sql/DataFrame.scala      |  18 +-
 .../apache/spark/sql/GroupedDataFrame.scala   |  21 +-
 4 files changed, 283 insertions(+), 112 deletions(-)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index f16eb361d306f..518901d142f56 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -1810,20 +1810,35 @@ def inherit_doc(cls):
 class DataFrame(object):
-    """An RDD of L{Row} objects that has an associated schema.
+    """A collection of rows that have the same columns.
 
-    The underlying JVM object is a DataFrame, so we can
-    utilize the relational query api exposed by Spark SQL.
+    A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
+    and can be created using various functions in :class:`SQLContext`::
 
-    For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
-    L{DataFrame} is not operated on directly, as it's underlying
-    implementation is an RDD composed of Java objects. Instead it is
-    converted to a PythonRDD in the JVM, on which Python operations can
-    be done.
+        people = sqlContext.parquetFile("...")
 
-    This class receives raw tuples from Java but assigns a class to it in
-    all its data-collection methods (mapPartitionsWithIndex, collect, take,
-    etc) so that PySpark sees them as Row objects with named fields.
+    Once created, it can be manipulated using the various domain-specific
+    language (DSL) functions defined in :class:`DataFrame` and :class:`Column`.
+
+    To select a column from the data frame, use attribute access::
+
+        ageCol = people.age
+
+    Note that the :class:`Column` type can also be manipulated
+    through its various functions::
+
+        # The following creates a new column that increases everybody's age by 10.
+        people.age + 10
+
+    A more concrete example::
+
+        # To create DataFrame using SQLContext
+        people = sqlContext.parquetFile("...")
+        department = sqlContext.parquetFile("...")
+
+        people.filter(people.age > 30).join(department, people.deptId == department.id) \
+            .groupby(department.name, "gender").agg({"salary": "avg", "age": "max"})
     """
 
     def __init__(self, jdf, sql_ctx):
@@ -1834,11 +1849,8 @@ def __init__(self, jdf, sql_ctx):
 
     @property
     def rdd(self):
-        """Lazy evaluation of PythonRDD object.
-
-        Only done when a user calls methods defined by the
-        L{pyspark.rdd.RDD} super class (map, filter, etc.).
-        """
+        """Return the content of the :class:`DataFrame` as an :class:`RDD`
+        of :class:`Row` objects. """
         if not hasattr(self, '_lazy_rdd'):
             jrdd = self._jdf.javaToPython()
             rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
@@ -1931,13 +1943,9 @@ def schema(self):
         a L{StructType})."""
         return _parse_datatype_json_string(self._jdf.schema().json())
 
-    # def schemaString(self):
-    #     """Returns the output schema in the tree format."""
-    #     return self._jdf.schemaString()
-    #
     def printSchema(self):
         """Prints out the schema in the tree format."""
-        self.printSchema()
+        print(self._jdf.schema().treeString())
 
     def count(self):
         """Return the number of elements in this RDD.
@@ -1988,9 +1996,6 @@ def take(self, num):
         """
         return self.limit(num).collect()
 
-    def first(self):
-        return self.rdd.first()
-
     def map(self, f):
         return self.rdd.map(f)
 
@@ -2031,34 +2036,10 @@ def unpersist(self, blocking=True):
     #     rdd = self._jdf.coalesce(numPartitions, shuffle, None)
     #     return DataFrame(rdd, self.sql_ctx)
 
-    # def distinct(self, numPartitions=None):
-    #     if numPartitions is None:
-    #         rdd = self._jdf.distinct()
-    #     else:
-    #         rdd = self._jdf.distinct(numPartitions, None)
-    #     return DataFrame(rdd, self.sql_ctx)
-    #
-    # def intersection(self, other):
-    #     if (other.__class__ is DataFrame):
-    #         rdd = self._jdf.intersection(other._jdf)
-    #         return DataFrame(rdd, self.sql_ctx)
-    #     else:
-    #         raise ValueError("Can only intersect with another DataFrame")
-
     # def repartition(self, numPartitions):
     #     rdd = self._jdf.repartition(numPartitions, None)
     #     return DataFrame(rdd, self.sql_ctx)
-    #
-    # def subtract(self, other, numPartitions=None):
-    #     if (other.__class__ is DataFrame):
-    #         if numPartitions is None:
-    #             rdd = self._jdf.subtract(other._jdf)
-    #         else:
-    #             rdd = self._jdf.subtract(other._jdf,
-    #                                       numPartitions)
-    #         return DataFrame(rdd, self.sql_ctx)
-    #     else:
-    #         raise ValueError("Can only subtract another DataFrame")
 
     def sample(self, withReplacement, fraction, seed=None):
         """
@@ -2090,16 +2071,31 @@ def sample(self, withReplacement, fraction, seed=None):
 
     @property
     def dtypes(self):
+        """Return all column names and their data types as a list.
+        """
        return [(f.name, str(f.dataType)) for f in self.schema().fields]
 
     @property
     def columns(self):
+        """ Return all column names as a list.
+        """
         return [f.name for f in self.schema().fields]
 
     def show(self):
         raise NotImplemented
 
     def join(self, other, joinExprs=None, joinType=None):
+        """
+        Join with another DataFrame, using the given join expression.
+        The following performs a full outer join between `df1` and `df2`::
+
+            df1.join(df2, df1.key == df2.key, "outer")
+
+        :param other: Right side of the join
+        :param joinExprs: Join expression
+        :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`,
+            `semijoin`.
+        """
         if joinType is None:
             if joinExprs is None:
                 jdf = self._jdf.join(other._jdf)
@@ -2110,6 +2106,11 @@ def join(self, other, joinExprs=None, joinType=None):
         return DataFrame(jdf, self.sql_ctx)
 
     def sort(self, *cols):
+        """ Return a new :class:`DataFrame` sorted by the specified column,
+        in ascending order.
+
+        :param cols: The columns or expressions used for sorting
+        """
         if not cols:
             raise ValueError("should sort by at least one column")
         for i, c in enumerate(cols):
@@ -2119,24 +2120,43 @@ def sort(self, *cols):
         jdf = self._jdf.join(*jcols)
         return DataFrame(jdf, self.sql_ctx)
 
+    sortBy = sort
+
+    def head(self, n=None):
+        """ Return the first `n` rows, or the first row if `n` is None.
""" + if n is None: + rs = self.head(1) + return rs[0] if rs else None + return self.take(n) + def tail(self): raise NotImplemented def __getitem__(self, item): if isinstance(item, basestring): return Column(self._jdf.apply(item)) + # TODO projection raise IndexError def __getattr__(self, name): + """ Return the column by given name """ if isinstance(name, basestring): return Column(self._jdf.apply(name)) raise AttributeError - def alias(self, name): + def As(self, name): + """ Alias the current DataFrame """ return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx) def select(self, *cols): + """ Selecting a set of expressions.:: + + df.select() + df.select('colA', 'colB') + df.select(df.colA, df.colB + 1) + + """ if not cols: cols = ["*"] if isinstance(cols[0], basestring): @@ -2144,34 +2164,98 @@ def select(self, *cols): else: cols = [c._jc for c in cols] jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) - jdf = self._jdf.select(jcols) + jdf = self._jdf.select(self._jdf.toColumnArray(jcols)) return DataFrame(jdf, self.sql_ctx) - def where(self, cond): - return DataFrame(self._jdf.filter(cond._jc), self.sql_ctx) + def filter(self, condition): + """ Filtering rows using the given condition:: + + df.filter(df.age > 15) + df.where(df.age > 15) - def filter(self, col): - return DataFrame(self._jdf.filter(col._jc), self.sql_ctx) + """ + return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx) + + where = filter def groupby(self, *cols): + """ Group the [[DataFrame]] using the specified columns, + so we can run aggregation on them. See :class:`GroupedDataFrame` + for all the available aggregate functions:: + + df.groupby(df.department).avg() + df.groupby("department", "gender").agg({ + "salary": "avg", + "age": "max", + }) + """ if cols and isinstance(cols[0], basestring): cols = [_create_column_from_name(n) for n in cols] else: cols = [c._jc for c in cols] jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) - jdf = self._jdf.groupby(jcols) + jdf = self._jdf.groupby(self._jdf.toColumnArray(jcols)) return GroupedDataFrame(jdf, self.sql_ctx) + def agg(self, *exprs): + """ Aggregate on the entire [[DataFrame]] without groups + (shorthand for df.groupby.agg()):: + + df.agg({"age": "max", "salary": "avg"}) + """ + return self.groupby().agg(*exprs) + + def unionAll(self, other): + """ Return a new DataFrame containing union of rows in this + frame and another frame. + + This is equivalent to `UNION ALL` in SQL. + """ + return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) + + def intersect(self, other): + """ Return a new [[DataFrame]] containing rows only in + both this frame and another frame. + + This is equivalent to `INTERSECT` in SQL. + """ + return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + + def Except(self, other): + """ Return a new [[DataFrame]] containing rows in this frame + but not in another frame. -# make SchemaRDD as an alias of DataFrame for backward compatibility -SchemaRDD = DataFrame + This is equivalent to `EXCEPT` in SQL. + """ + return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + + def sample(self, withReplacement, fraction, seed=None): + """ Return a new DataFrame by sampling a fraction of rows. """ + if seed is None: + jdf = self._jdf.sample(withReplacement, fraction) + else: + jdf = self._jdf.sample(withReplacement, fraction, seed) + return DataFrame(jdf, self.sql_ctx) + + def addColumn(self, colName, col): + """ Return a new [[DataFrame]] by adding a column. 
""" + return self.select('*', col.As(colName)) + + def removeColumn(self, colName): + raise NotImplemented + + +# Having SchemaRDD for backward compatibility (for docs) +class SchemaRDD(DataFrame): + """ + SchemaRDD is deprecated, please use DataFrame + """ def dfapi(f): - def _api(self, *a): - ja = [v._jc if isinstance(v, Column) else v for v in a] + def _api(self): name = f.__name__ - jdf = getattr(self._jdf, name)(*ja) + jdf = getattr(self._jdf, name)() return DataFrame(jdf, self.sql_ctx) _api.__name__ = f.__name__ _api.__doc__ = f.__doc__ @@ -2179,13 +2263,29 @@ def _api(self, *a): class GroupedDataFrame(object): + + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by DataFrame.groupby(). + """ + def __init__(self, jdf, sql_ctx): self._jdf = jdf self.sql_ctx = sql_ctx def agg(self, *exprs): + """ Compute aggregates by specifying a map from column name + to aggregate methods. + + The available aggregate methods are `avg`, `max`, `min`, + `sum`, `count`. + + :param exprs: list or aggregate columns or a map from column + name to agregate methods. + """ if len(exprs) == 1 and isinstance(exprs[0], dict): - jmap = MapConverter().convert(exprs[0], self.sql_ctx._sc._gateway._gateway_client) + jmap = MapConverter().convert(exprs[0], + self.sql_ctx._sc._gateway._gateway_client) jdf = self._jdf.agg(jmap) else: # Columns @@ -2195,23 +2295,32 @@ def agg(self, *exprs): @dfapi def count(self): - """ """ + """ Count the number of rows for each group. """ @dfapi def mean(self): - """""" + """Compute the average value for each numeric columns + for each group. This is an alias for `avg`.""" + + @dfapi + def avg(self): + """Compute the average value for each numeric columns + for each group.""" @dfapi def max(self): - """""" + """Compute the max value for each numeric columns for + each group. """ @dfapi def min(self): - """""" + """Compute the min value for each numeric column for + each group.""" @dfapi def sum(self): - """""" + """Compute the sum for each numeric columns for each + group.""" SCALA_METHOD_MAPPINGS = { @@ -2246,40 +2355,67 @@ def _create_column_from_name(name): return sc._jvm.Column(name) -def scalaMethod(name): +def _scalaMethod(name): + """ Translate operators into methodName in Scala + + For example: + >>> scalaMethod('+') + '$plus' + >>> scalaMethod('>=') + '$greater$eq' + >>> scalaMethod('cast') + 'cast' + """ return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name) def _unary_op(name): + """ Create a method for given unary operator """ def _(self): - return Column(getattr(self._jc, scalaMethod(name))()) + return Column(getattr(self._jc, _scalaMethod(name))(), self._jdf, self.sql_ctx) return _ def _bin_op(name): + """ Create a method for given binary operator """ def _(self, other): if isinstance(other, Column): jc = other._jc else: jc = _create_column_from_literal(other) - return Column(getattr(self._jc, scalaMethod(name))(jc)) + return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx) return _ def _reverse_op(name): + """ Create a method for binary operator (this object is on right side) + """ def _(self, other): - return Column(getattr(_create_column_from_literal(other), scalaMethod(name))(self._jc)) + return Column(getattr(_create_column_from_literal(other), _scalaMethod(name))(self._jc), + self._jdf, self.sql_ctx) return _ class Column(DataFrame): + + """ + A column in a DataFrame. + + `Column` instances can be created by: + {{{ + // 1. Select a column out of a DataFrame + df.colName + df["colName"] + + // 2. 
+        df["colName"] + 1
+    """
+
     def __init__(self, jc, jdf=None, sql_ctx=None):
         self._jc = jc
         super(Column, self).__init__(jdf, sql_ctx)
 
-    def __nonzero__(self):
-        return True
-
     # arithmetic operators
     __neg__ = _unary_op("unary_-")
     __add__ = _bin_op("+")
@@ -2324,10 +2460,6 @@ def __nonzero__(self):
     # container operators
     __contains__ = _bin_op("contains")
     __getitem__ = _bin_op("getItem")
-
-    def __getslice__(self, a, b):
-        jc = self._jsc.substr(a, b - a)
-        return Column(jc)
     # __getattr__ = _bin_op("getField")
 
     # string methods
@@ -2338,19 +2470,43 @@ def __getslice__(self, a, b):
     upper = _unary_op("upper")
     lower = _unary_op("lower")
 
+    def substr(self, startPos, pos):
+        """ Return a Column which is a substring of the column """
+        if type(startPos) != type(pos):
+            raise TypeError("startPos and pos must have the same type")
+        if isinstance(startPos, (int, long)):
+            jc = self._jc.substr(startPos, pos)
+        elif isinstance(startPos, Column):
+            jc = self._jc.substr(startPos._jc, pos._jc)
+        else:
+            raise TypeError("Unexpected type: %s" % type(startPos))
+        return Column(jc, self._jdf, self.sql_ctx)
+
+    def __getslice__(self, start, stop):
+        return self.substr(start, stop - start)
+
     # order
     asc = _unary_op("asc")
     desc = _unary_op("desc")
 
+    isNull = _unary_op("isNull")
+    isNotNull = _unary_op("isNotNull")
+
     # `as` is keyword
     def As(self, alias):
-        return Column(getattr(self._jsc, "as")(alias))
+        return Column(getattr(self._jc, "as")(alias), self._jdf, self.sql_ctx)
 
     def cast(self, dataType):
-        raise NotImplemented
+        """ Convert the column into the given dataType """
+        if self.sql_ctx is None:
+            sc = SparkContext._active_spark_context
+            ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+        else:
+            ssql_ctx = self.sql_ctx._ssql_ctx
+        jdt = ssql_ctx.parseDataType(dataType.json())
+        return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)
 
 
-def _help_func(name):
+def _aggregate_func(name):
+    """ Create a function for an aggregator by name """
     def _(col):
         sc = SparkContext._active_spark_context
         if isinstance(col, Column):
@@ -2364,14 +2520,16 @@ def _(col):
 
 
 class Aggregator(object):
-    # helper functions
-    max = _help_func("max")
-    min = _help_func("min")
-    avg = mean = _help_func("mean")
-    sum = _help_func("sum")
-    first = _help_func("first")
-    last = _help_func("last")
-    count = _help_func("count")
+    """
+    A collection of builtin aggregators
+    """
+    max = _aggregate_func("max")
+    min = _aggregate_func("min")
+    avg = mean = _aggregate_func("mean")
+    sum = _aggregate_func("sum")
+    first = _aggregate_func("first")
+    last = _aggregate_func("last")
+    count = _aggregate_func("count")
 
 
 def _test():
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 142d7ad5f0f9b..c8df2fc6ef956 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -873,7 +873,7 @@ def test_serialize_nested_array_and_map(self):
         d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
         rdd = self.sc.parallelize(d)
         df = self.sqlCtx.inferSchema(rdd)
-        row = df.first()
+        row = df.head()
         self.assertEqual(1, len(row.l))
         self.assertEqual(1, row.l[0].a)
         self.assertEqual("2", row.d["key"].d)
@@ -899,7 +899,7 @@ def test_infer_schema(self):
         self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
         df.registerTempTable("test")
         result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
-        self.assertEqual(1, result.first()[0])
+        self.assertEqual(1, result.head()[0])
 
         df2 = self.sqlCtx.inferSchema(rdd, 1.0)
         self.assertEqual(df.schema(), df2.schema())
@@ -907,13 +907,13 @@ def test_infer_schema(self):
         self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
         df2.registerTempTable("test2")
         result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
-        self.assertEqual(1, result.first()[0])
+        self.assertEqual(1, result.head()[0])
 
     def test_struct_in_map(self):
         d = [Row(m={Row(i=1): Row(s="")})]
         rdd = self.sc.parallelize(d)
         df = self.sqlCtx.inferSchema(rdd)
-        k, v = df.first().m.items()[0]
+        k, v = df.head().m.items()[0]
         self.assertEqual(1, k.i)
         self.assertEqual("", v.s)
 
@@ -923,7 +923,7 @@ def test_convert_row_to_dict(self):
         rdd = self.sc.parallelize([row])
         df = self.sqlCtx.inferSchema(rdd)
         df.registerTempTable("test")
-        row = self.sqlCtx.sql("select l, d from test").first()
+        row = self.sqlCtx.sql("select l, d from test").head()
         self.assertEqual(1, row.asDict()["l"][0].a)
         self.assertEqual(1.0, row.asDict()['d']['key'].c)
 
@@ -936,7 +936,7 @@ def test_infer_schema_with_udt(self):
         field = [f for f in schema.fields if f.name == "point"][0]
         self.assertEqual(type(field.dataType), ExamplePointUDT)
         df.registerTempTable("labeled_point")
-        point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point
+        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
         self.assertEqual(point, ExamplePoint(1.0, 2.0))
 
     def test_apply_schema_with_udt(self):
@@ -946,7 +946,8 @@ def test_apply_schema_with_udt(self):
         schema = StructType([StructField("label", DoubleType(), False),
                              StructField("point", ExamplePointUDT(), False)])
         df = self.sqlCtx.applySchema(rdd, schema)
-        point = df.first().point
+        # TODO: test collect with UDT
+        point = df.rdd.first().point
         self.assertEquals(point, ExamplePoint(1.0, 2.0))
 
     def test_parquet_with_udt(self):
@@ -957,11 +958,11 @@ def test_parquet_with_udt(self):
         output_dir = os.path.join(self.tempdir.name, "labeled_point")
         df0.saveAsParquetFile(output_dir)
         df1 = self.sqlCtx.parquetFile(output_dir)
-        point = df1.first().point
+        point = df1.head().point
         self.assertEquals(point, ExamplePoint(1.0, 2.0))
 
     def test_column_operators(self):
-        from pyspark.sql import Column
+        from pyspark.sql import Column, LongType
         ci = self.df.key
         cs = self.df.value
         c = ci == cs
@@ -974,8 +975,9 @@ def test_column_operators(self):
         self.assertTrue(all(isinstance(c, Column) for c in cbit))
         css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
         self.assertTrue(all(isinstance(c, Column) for c in css))
+        self.assertTrue(isinstance(ci.cast(LongType()), Column))
 
-    def test_column(self):
+    def test_column_select(self):
         df = self.df
         self.assertEqual(self.testData, df.select("*").collect())
         self.assertEqual(self.testData, df.select(df.key, df.value).collect())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index f6178d775a38c..72145eb1cdf56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.sql
 
+import scala.collection.mutable.ArrayBuffer
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
+import scala.collection.JavaConverters._
 import scala.collection.JavaConversions._
 
 import java.util.{ArrayList, List => JList}
@@ -570,17 +572,17 @@ class DataFrame protected[sql](
   ////////////////////////////////////////////////////////////////////////////
   // for Python API
   ////////////////////////////////////////////////////////////////////////////
-  private[sql] def select(cols: java.util.List[Column]): DataFrame = {
-    select(cols:_*)
-  }
 
-  private[sql] def groupby(cols: java.util.List[Column]): GroupedDataFrame = {
-    groupby(cols:_*)
+  /**
+   * A helper function for Py4J that converts a list of Columns to an array.
+   */
+  protected[sql] def toColumnArray(cols: JList[Column]): Array[Column] = {
+    cols.toList.toArray
   }
 
   /**
-   * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
+   * Converts a JavaRDD to a PythonRDD.
    */
-  private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+  protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
     val fieldTypes = schema.fields.map(_.dataType)
     val jrdd = this.rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
     SerDeUtil.javaToPython(jrdd)
@@ -588,7 +590,7 @@ class DataFrame protected[sql](
   /**
    * Serializes the Array[Row] returned by collect(), using the same format as javaToPython.
    */
-  private[sql] def collectToPython: JList[Array[Byte]] = {
+  protected[sql] def collectToPython: JList[Array[Byte]] = {
     val fieldTypes = schema.fields.map(_.dataType)
     val pickle = new Pickler
     new ArrayList[Array[Byte]](collect().map { row =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
index ef5ef6b09dc6c..2e1ef7cf976ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql
 
 import scala.language.implicitConversions
-import scala.collection.JavaConverters._
 import scala.collection.JavaConversions._
 
 import org.apache.spark.sql.catalyst.expressions._
@@ -76,6 +75,21 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
     }.toSeq
   }
 
+  /**
+   * Compute aggregates by specifying a map from column name to aggregate methods.
+   * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
+   * {{{
+   *   // Selects the age of the oldest employee and the aggregate expense for each department
+   *   df.groupby("department").agg(Map(
+   *     "age" -> "max",
+   *     "expense" -> "sum"
+   *   ))
+   * }}}
+   */
+  def agg(exprs: java.util.Map[String, String]): DataFrame = {
+    agg(exprs.toMap)
+  }
+
   /**
    * Compute aggregates by specifying a series of aggregate columns.
    * The available aggregate methods are defined in [[org.apache.spark.sql.dsl]].
@@ -122,9 +136,4 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
    * Compute the sum for each numeric columns for each group.
    */
   override def sum(): DataFrame = aggregateNumericColumns(Sum)
-
-  //// For Python API
-  private[sql] def agg(exprs: java.util.Map[String, String]): DataFrame = {
-    agg(exprs.toMap)
-  }
 }
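-- 
A minimal end-to-end sketch of the Python API this patch adds (placed after the
patch trailer so it is ignored by git am). It assumes a local SparkContext; the
table and column names (people, age, deptId) are invented for illustration and
are not part of the patch::

    from pyspark import SparkContext
    from pyspark.sql import Row, SQLContext

    sc = SparkContext("local", "dataframe-demo")
    sqlContext = SQLContext(sc)

    # Infer a schema from an RDD of Rows, then exercise the new DSL methods.
    rdd = sc.parallelize([Row(name="Alice", age=25, deptId=1),
                          Row(name="Bob", age=40, deptId=2)])
    people = sqlContext.inferSchema(rdd)

    adults = people.filter(people.age > 30)               # same rows as people.where(...)
    stats = people.groupby("deptId").agg({"age": "max"})  # max age per department id
    print(stats.head())                                   # first aggregated Row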