Python APIs.
yhuai committed Feb 10, 2015
1 parent c204967 commit 2bf44ef
Showing 2 changed files with 194 additions and 5 deletions.
96 changes: 93 additions & 3 deletions python/pyspark/sql.py
@@ -1622,6 +1622,48 @@ def func(iterator):
        df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
        return DataFrame(df, self)

    def load(self, path=None, dataSourceName=None, schema=None, **options):
        """Returns the dataset specified by the data source and a set of options
        as a DataFrame. An optional schema can be applied as the schema of the
        returned DataFrame. If dataSourceName is not provided, the default data source
        configured by spark.sql.sources.default will be used.
        """
        if path is not None:
            options["path"] = path
        if dataSourceName is None:
            dataSourceName = self._ssql_ctx.getConf("spark.sql.sources.default",
                                                    "org.apache.spark.sql.parquet")
        joptions = MapConverter().convert(options,
                                          self._sc._gateway._gateway_client)
        if schema is None:
            df = self._ssql_ctx.load(dataSourceName, joptions)
        else:
            scala_datatype = self._ssql_ctx.parseDataType(schema.json())
            df = self._ssql_ctx.load(dataSourceName, scala_datatype, joptions)
        return DataFrame(df, self)
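
For reference, a minimal usage sketch of the new load API (not part of this commit; it assumes an existing SQLContext named sqlCtx, and the /tmp/people.json path is hypothetical — the JSON source name is the one exercised by the tests below):

from pyspark.sql import StructType, StructField, StringType

# Hypothetical path; the data source class is the built-in JSON source.
people = sqlCtx.load("/tmp/people.json", "org.apache.spark.sql.json")

# Supplying a schema overrides inference for the returned DataFrame.
schema = StructType([StructField("name", StringType(), True)])
named = sqlCtx.load("/tmp/people.json", "org.apache.spark.sql.json", schema)

# With no dataSourceName, spark.sql.sources.default (Parquet unless reconfigured) is used.
defaults = sqlCtx.load(path="/tmp/people.parquet")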

    def createExternalTable(self, tableName, path=None, dataSourceName=None,
                            schema=None, **options):
        """Creates an external table based on the given data source and a set of options,
        and returns the corresponding DataFrame.
        If dataSourceName is not provided, the default data source configured
        by spark.sql.sources.default will be used.
        """
        if path is not None:
            options["path"] = path
        if dataSourceName is None:
            dataSourceName = self._ssql_ctx.getConf("spark.sql.sources.default",
                                                    "org.apache.spark.sql.parquet")
        joptions = MapConverter().convert(options,
                                          self._sc._gateway._gateway_client)
        if schema is None:
            df = self._ssql_ctx.createExternalTable(tableName, dataSourceName, joptions)
        else:
            scala_datatype = self._ssql_ctx.parseDataType(schema.json())
            df = self._ssql_ctx.createExternalTable(tableName, dataSourceName, scala_datatype,
                                                    joptions)
        return DataFrame(df, self)
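
Similarly, a hedged sketch of createExternalTable (as in the tests below, a HiveContext-backed sqlCtx is assumed since the call registers a table in the metastore; the table and path names are hypothetical):

# The data stays at its original location; the external table simply points at it.
people_ext = sqlCtx.createExternalTable("people_ext", "/tmp/people.json",
                                        "org.apache.spark.sql.json")
rows = sqlCtx.sql("SELECT * FROM people_ext").collect()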

    def sql(self, sqlQuery):
        """Return a L{DataFrame} representing the result of the given query.
@@ -1889,9 +1931,57 @@ def insertInto(self, tableName, overwrite=False):
"""
self._jdf.insertInto(tableName, overwrite)

    def saveAsTable(self, tableName):
        """Creates a new table with the contents of this DataFrame."""
        self._jdf.saveAsTable(tableName)
    def saveAsTable(self, tableName, dataSourceName=None, mode="append", **options):
        """Creates a new table with the contents of this DataFrame, based on the given
        data source, save mode, and a set of options. If a data source is not provided,
        the default data source configured by spark.sql.sources.default will be used.
        """
        if dataSourceName is None:
            dataSourceName = self.sql_ctx._ssql_ctx.getConf("spark.sql.sources.default",
                                                            "org.apache.spark.sql.parquet")
        jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.ErrorIfExists
        mode = mode.lower()
        if mode == "append":
            jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Append
        elif mode == "overwrite":
            jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Overwrite
        elif mode == "ignore":
            jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Ignore
        elif mode == "error":
            pass
        else:
            raise ValueError(
                "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save modes.")
        joptions = MapConverter().convert(options,
                                          self.sql_ctx._sc._gateway._gateway_client)
        self._jdf.saveAsTable(tableName, dataSourceName, jmode, joptions)
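
A hedged usage sketch of the new saveAsTable signature (df is any existing DataFrame; the table and path names are hypothetical, and a HiveContext is assumed as in the tests below):

# Persist df as the metastore table "savedJsonTable", backed by JSON files at the given path.
df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite",
               path="/tmp/savedJsonTable")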

    def save(self, path=None, dataSourceName=None, mode="append", **options):
        """Saves the contents of the DataFrame to a data source, based on the given
        data source name, save mode, and a set of options. If a data source is not provided,
        the default data source configured by spark.sql.sources.default will be used.
        """
        if path is not None:
            options["path"] = path
        if dataSourceName is None:
            dataSourceName = self.sql_ctx._ssql_ctx.getConf("spark.sql.sources.default",
                                                            "org.apache.spark.sql.parquet")
        jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.ErrorIfExists
        mode = mode.lower()
        if mode == "append":
            jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Append
        elif mode == "overwrite":
            jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Overwrite
        elif mode == "ignore":
            jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Ignore
        elif mode == "error":
            pass
        else:
            raise ValueError(
                "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save modes.")
        joptions = MapConverter().convert(options,
                                          self._sc._gateway._gateway_client)
        self._jdf.save(dataSourceName, jmode, joptions)
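
And a sketch of save (the path is hypothetical; extra keyword arguments are passed through to the data source as options):

# "error" raises if the target already exists; "overwrite" replaces it.
df.save("/tmp/df_json", "org.apache.spark.sql.json", "error")
df.save(path="/tmp/df_json", dataSourceName="org.apache.spark.sql.json", mode="overwrite")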

    def schema(self):
        """Returns the schema of this DataFrame (represented by
103 changes: 101 additions & 2 deletions python/pyspark/sql_tests.py
@@ -34,8 +34,8 @@
else:
    import unittest

from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
    UserDefinedType, DoubleType
from pyspark.sql import SQLContext, HiveContext, IntegerType, Row, ArrayType, StructType, \
    StructField, UserDefinedType, DoubleType
from pyspark.tests import ReusedPySparkTestCase


@@ -285,6 +285,38 @@ def test_aggregator(self):
        self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
        self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])

    def test_save_and_load(self):
        df = self.df
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        df.save(tmpPath, "org.apache.spark.sql.json", "error")
        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))

        from pyspark.sql import StructType, StructField, StringType
        schema = StructType([StructField("value", StringType(), True)])
        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema)
        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))

        df.save(tmpPath, "org.apache.spark.sql.json", "overwrite")
        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))

        df.save(dataSourceName="org.apache.spark.sql.json", mode="overwrite", path=tmpPath,
                noUse="this option will not be used in save.")
        actual = self.sqlCtx.load(dataSourceName="org.apache.spark.sql.json", path=tmpPath,
                                  noUse="this option will not be used in load.")
        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))

        defaultDataSourceName = self.sqlCtx._ssql_ctx.getConf("spark.sql.sources.default",
                                                              "org.apache.spark.sql.parquet")
        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
        actual = self.sqlCtx.load(path=tmpPath)
        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)

        shutil.rmtree(tmpPath)

    def test_help_command(self):
        # Regression test for SPARK-5464
        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
@@ -294,6 +326,73 @@ def test_help_command(self):
        pydoc.render_doc(df.foo)
        pydoc.render_doc(df.take(1))

class HiveContextSQLTests(ReusedPySparkTestCase):

    @classmethod
    def setUpClass(cls):
        ReusedPySparkTestCase.setUpClass()
        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
        os.unlink(cls.tempdir.name)
        cls.sqlCtx = HiveContext(cls.sc)
        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
        rdd = cls.sc.parallelize(cls.testData)
        cls.df = cls.sqlCtx.inferSchema(rdd)

    @classmethod
    def tearDownClass(cls):
        ReusedPySparkTestCase.tearDownClass()
        shutil.rmtree(cls.tempdir.name, ignore_errors=True)

    def test_save_and_load_table(self):
        df = self.df
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath)
        actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath,
                                                 "org.apache.spark.sql.json")
        self.assertTrue(
            sorted(df.collect()) ==
            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
        self.assertTrue(
            sorted(df.collect()) ==
            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
        self.sqlCtx.sql("DROP TABLE externalJsonTable")

        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath)
        from pyspark.sql import StructType, StructField, StringType
        schema = StructType([StructField("value", StringType(), True)])
        actual = self.sqlCtx.createExternalTable("externalJsonTable",
                                                 dataSourceName="org.apache.spark.sql.json",
                                                 schema=schema, path=tmpPath,
                                                 noUse="this option will not be used")
        self.assertTrue(
            sorted(df.collect()) ==
            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
        self.assertTrue(
            sorted(df.select("value").collect()) ==
            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
        self.sqlCtx.sql("DROP TABLE savedJsonTable")
        self.sqlCtx.sql("DROP TABLE externalJsonTable")

        defaultDataSourceName = self.sqlCtx._ssql_ctx.getConf("spark.sql.sources.default",
                                                              "org.apache.spark.sql.parquet")
        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
        df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
        actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
        self.assertTrue(
            sorted(df.collect()) ==
            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
        self.assertTrue(
            sorted(df.collect()) ==
            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
        self.sqlCtx.sql("DROP TABLE savedJsonTable")
        self.sqlCtx.sql("DROP TABLE externalJsonTable")
        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)

        shutil.rmtree(tmpPath)

if __name__ == "__main__":
    unittest.main()
