createDataFrame from RDD with columns
Davies Liu committed Feb 10, 2015
1 parent bd0b5ea commit 9526e97
Showing 2 changed files with 208 additions and 12 deletions.
66 changes: 64 additions & 2 deletions python/pyspark/sql/context.py
@@ -24,7 +24,7 @@

from pyspark.rdd import _prepare_for_python_RDD
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import StringType, StructType, _verify_type, \
+from pyspark.sql.types import StringType, StructType, _infer_type, _verify_type, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
from pyspark.sql.dataframe import DataFrame

@@ -118,6 +118,9 @@ def registerFunction(self, name, f, returnType=StringType()):
    def inferSchema(self, rdd, samplingRatio=None):
        """Infer and apply a schema to an RDD of L{Row}.

        .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead.

        When samplingRatio is specified, the schema is inferred by looking
        at the types of each row in the sampled dataset. Otherwise, the
        first 100 rows of the RDD are inspected. Nested collections are
@@ -186,7 +189,7 @@ def inferSchema(self, rdd, samplingRatio=None):
                    warnings.warn("Some of types cannot be determined by the "
                                  "first 100 rows, please try again with sampling")
        else:
-            if samplingRatio > 0.99:
+            if samplingRatio < 0.99:
                rdd = rdd.sample(False, float(samplingRatio))
            schema = rdd.map(_infer_schema).reduce(_merge_type)

@@ -198,6 +201,9 @@ def applySchema(self, rdd, schema):
"""
Applies the given schema to the given RDD of L{tuple} or L{list}.
::note:
Deprecated in 1.3, use :func:`createDataFrame` instead
These tuples or lists can contain complex nested structures like
lists, maps or nested rows.
@@ -287,6 +293,62 @@ def applySchema(self, rdd, schema):
        df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
        return DataFrame(df, self)

    def createDataFrame(self, rdd, schema=None, samplingRatio=None):
        """
        Create a DataFrame from an RDD of tuple/list and an optional `schema`.

        `schema` could be :class:`StructType` or a list of column names.

        When `schema` is a list of column names, the type of each column
        will be inferred from `rdd`.

        When `schema` is None, it will try to infer the column names and
        types from `rdd`, which should be an RDD of :class:`Row`, or
        namedtuple, or dict.

        If schema inference is needed, `samplingRatio` is used to determine
        how many rows will be used to do the inference. The first row will
        be used if `samplingRatio` is None.

        :param rdd: an RDD of Row or tuple or list or dict
        :param schema: a StructType or list of names of columns
        :param samplingRatio: the sample ratio of rows used for inferring
        :return: a DataFrame

        >>> rdd = sc.parallelize([('Alice', 1)])
        >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age'])
        >>> df.collect()
        [Row(name=u'Alice', age=1)]

        >>> from pyspark.sql import Row
        >>> Person = Row('name', 'age')
        >>> person = rdd.map(lambda r: Person(*r))
        >>> df2 = sqlCtx.createDataFrame(person)
        >>> df2.collect()
        [Row(name=u'Alice', age=1)]

        >>> from pyspark.sql.types import *
        >>> schema = StructType([
        ...    StructField("name", StringType(), True),
        ...    StructField("age", IntegerType(), True)])
        >>> df3 = sqlCtx.createDataFrame(rdd, schema)
        >>> df3.collect()
        [Row(name=u'Alice', age=1)]
        """
        if isinstance(rdd, DataFrame):
            raise TypeError("rdd is already a DataFrame")

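        # Dispatch on `schema`: an explicit StructType is applied as-is; a
        # list or tuple of column names is wrapped into Row objects so the
        # column types can be inferred; schema=None infers names and types.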
        if isinstance(schema, StructType):
            return self.applySchema(rdd, schema)
        else:
            if isinstance(schema, (list, tuple)):
                first = rdd.first()
                if not isinstance(first, (list, tuple)):
                    raise ValueError("each row in `rdd` should be list or tuple")
                row_cls = Row(*schema)
                rdd = rdd.map(lambda r: row_cls(*r))
            return self.inferSchema(rdd, samplingRatio)

    def registerRDDAsTable(self, rdd, tableName):
        """Registers the given RDD as a temporary table in the catalog.
154 changes: 144 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -243,7 +243,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   *  val people =
   *    sc.textFile("examples/src/main/resources/people.txt").map(
   *      _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
-  *  val dataFrame = sqlContext. applySchema(people, schema)
+  *  val dataFrame = sqlContext.createDataFrame(people, schema)
   *  dataFrame.printSchema
   *  // root
   *  // |-- name: string (nullable = false)
@@ -252,20 +252,90 @@ class SQLContext(@transient val sparkContext: SparkContext)
   *  dataFrame.registerTempTable("people")
   *  sqlContext.sql("select name from people").collect.foreach(println)
   * }}}
   *
   * @group userf
   */
  @DeveloperApi
-  def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+  def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
    // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
    // schema differs from the existing schema on any field data type.
    val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
    DataFrame(this, logicalPlan)
  }

  @DeveloperApi
-  def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
-    applySchema(rowRDD.rdd, schema);
+  def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
+    createDataFrame(rowRDD.rdd, schema)
  }

  /**
   * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying
   * a sequence of column names to this RDD; the data type for each column
   * will be inferred from the first row.
   *
   * @param rowRDD an RDD of Row
   * @param columns names for each column
   * @return DataFrame
   */
  def createDataFrame(rowRDD: RDD[Row], columns: Seq[String]): DataFrame = {
    def inferType(obj: Any): DataType = obj match {
      case null => NullType
      case _: Int => IntegerType
      case _: java.lang.Integer => IntegerType
      case _: String => StringType
      case _: Double => DoubleType
      case _: java.lang.Float => FloatType
      case _: Float => FloatType
      case _: Byte => ByteType
      case _: java.lang.Byte => ByteType
      case _: Boolean => BooleanType
      case _: java.lang.Boolean => BooleanType
      case _: java.math.BigDecimal => DecimalType()
      case _: java.sql.Date => DateType
      case _: java.util.Calendar => TimestampType
      case _: java.sql.Timestamp => TimestampType
      case map: Map[_, _] =>
        if (map.isEmpty) {
          throw new Exception("Cannot infer type from empty Map")
        }
        val (k, v) = map.head
        MapType(inferType(k), inferType(v), true)
      case map: java.util.Map[_, _] =>
        if (map.isEmpty) {
          throw new Exception("Cannot infer type from empty Map")
        }
        val (k, v) = map.head
        MapType(inferType(k), inferType(v), true)
      case seq: Seq[Any] =>
        if (seq.isEmpty) {
          throw new Exception("Cannot infer type from empty seq")
        }
        ArrayType(inferType(seq.head), true)
      case arr: Array[Any] =>
        if (arr.isEmpty) {
          throw new Exception("Cannot infer type from empty array")
        }
        ArrayType(inferType(arr.head), true)
      case other =>
        throw new Exception(s"Cannot infer type from $other")
    }

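    // The column types are inferred from the first row only: a null infers
    // NullType, and empty maps/seqs/arrays cannot be inferred and throw.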
    val first = rowRDD.first()
    val types = first.toSeq.map(inferType)
    val fields = columns.zip(types).map(x => new StructField(x._1, x._2, true))
    val schema = StructType(fields)
    createDataFrame(rowRDD, schema)
  }

  /**
   * Creates a [[DataFrame]] from a [[JavaRDD]] containing [[Row]]s by applying
   * a sequence of column names to this RDD; the data type for each column
   * will be inferred from the first row.
   *
   * @param rowRDD a JavaRDD of Row
   * @param columns names for each column
   * @return DataFrame
   */
  def createDataFrame(rowRDD: JavaRDD[Row], columns: java.util.List[String]): DataFrame = {
    createDataFrame(rowRDD.rdd, columns.toSeq)
  }
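
For illustration, a minimal usage sketch of the new column-names overload (the names `sc`, `rows`, and `df` are illustrative, assuming an existing SparkContext `sc`):

    import org.apache.spark.sql._

    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    val rows = sc.parallelize(Seq(Row("Alice", 1), Row("Bob", 2)))
    // "name" is inferred as StringType and "age" as IntegerType
    // from the first row, Row("Alice", 1).
    val df = sqlContext.createDataFrame(rows, Seq("name", "age"))
    df.printSchema
    // root
    //  |-- name: string (nullable = true)
    //  |-- age: integer (nullable = true)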

/**
@@ -274,7 +344,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
   * SELECT * queries will return the columns in an undefined order.
   */
-  def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
+  def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
    val attributeSeq = getSchema(beanClass)
    val className = beanClass.getName
    val rowRdd = rdd.mapPartitions { iter =>
@@ -301,8 +371,72 @@ class SQLContext(@transient val sparkContext: SparkContext)
   * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
   * SELECT * queries will return the columns in an undefined order.
   */
  def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
    createDataFrame(rdd.rdd, beanClass)
  }
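
A sketch of the bean overload; here `Person` is a hypothetical Java bean (not part of this commit) with getName/setName and getAge/setAge accessors:

    // Person is assumed: a plain Java bean with name and age properties.
    val people = sc.parallelize(Seq(new Person("Alice", 1), new Person("Bob", 2)))
    val df = sqlContext.createDataFrame(people, classOf[Person])
    // Bean field order is not guaranteed, so select columns explicitly
    // instead of relying on SELECT *.
    df.select("name", "age").collect().foreach(println)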

  /**
   * :: DeveloperApi ::
   * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD.
   * It is important to make sure that the structure of every [[Row]] of the provided RDD matches
   * the provided schema. Otherwise, there will be a runtime exception.
   * Example:
   * {{{
   *  import org.apache.spark.sql._
   *  val sqlContext = new org.apache.spark.sql.SQLContext(sc)
   *
   *  val schema =
   *    StructType(
   *      StructField("name", StringType, false) ::
   *      StructField("age", IntegerType, true) :: Nil)
   *
   *  val people =
   *    sc.textFile("examples/src/main/resources/people.txt").map(
   *      _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
   *  val dataFrame = sqlContext.applySchema(people, schema)
   *  dataFrame.printSchema
   *  // root
   *  // |-- name: string (nullable = false)
   *  // |-- age: integer (nullable = true)
   *
   *  dataFrame.registerTempTable("people")
   *  sqlContext.sql("select name from people").collect.foreach(println)
   * }}}
   *
   * @group userf
   */
  @DeveloperApi
  @deprecated("use createDataFrame", "1.3.0")
  def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
    createDataFrame(rowRDD, schema)
  }

  @DeveloperApi
  @deprecated("use createDataFrame", "1.3.0")
  def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
    createDataFrame(rowRDD, schema)
  }

  /**
   * Applies a schema to an RDD of Java Beans.
   *
   * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
   * SELECT * queries will return the columns in an undefined order.
   */
  @deprecated("use createDataFrame", "1.3.0")
  def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
    createDataFrame(rdd, beanClass)
  }

  /**
   * Applies a schema to an RDD of Java Beans.
   *
   * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
   * SELECT * queries will return the columns in an undefined order.
   */
  @deprecated("use createDataFrame", "1.3.0")
  def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
-    applySchema(rdd.rdd, beanClass)
+    createDataFrame(rdd, beanClass)
  }
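
Since every deprecated overload simply forwards to its `createDataFrame` counterpart with the same arguments, migration is a mechanical rename; a sketch assuming an RDD[Row] `people` and a StructType `schema` as in the example above:

    // Before (deprecated in 1.3.0):
    val df1 = sqlContext.applySchema(people, schema)
    // After:
    val df2 = sqlContext.createDataFrame(people, schema)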

/**
@@ -375,7 +509,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
      JsonRDD.nullTypeToStringType(
        JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
    val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
-    applySchema(rowRDD, appliedSchema)
+    createDataFrame(rowRDD, appliedSchema)
  }

@Experimental
@@ -393,7 +527,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
      JsonRDD.nullTypeToStringType(
        JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
    val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
-    applySchema(rowRDD, appliedSchema)
+    createDataFrame(rowRDD, appliedSchema)
  }

@Experimental
