From 6ada4f6f52cf1d992c7ab0c32318790cf08b0a0d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 8 Apr 2015 13:31:45 -0700 Subject: [PATCH] [SPARK-6781] [SQL] use sqlContext in python shell Use `sqlContext` in PySpark shell, make it consistent with SQL programming guide. `sqlCtx` is also kept for compatibility. Author: Davies Liu Closes #5425 from davies/sqlCtx and squashes the following commits: af67340 [Davies Liu] sqlCtx -> sqlContext 15a278f [Davies Liu] use sqlContext in python shell --- docs/ml-guide.md | 2 +- docs/sql-programming-guide.md | 4 +- .../spark/examples/sql/JavaSparkSQL.java | 20 ++--- .../ml/simple_text_classification_pipeline.py | 2 +- .../src/main/python/mllib/dataset_example.py | 6 +- python/pyspark/ml/classification.py | 4 +- python/pyspark/ml/feature.py | 4 +- python/pyspark/shell.py | 6 +- python/pyspark/sql/context.py | 79 +++++++++---------- python/pyspark/sql/dataframe.py | 6 +- python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/types.py | 4 +- 12 files changed, 69 insertions(+), 70 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c08c76d226713..771a07183e26f 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -493,7 +493,7 @@ from pyspark.ml.feature import HashingTF, Tokenizer from pyspark.sql import Row, SQLContext sc = SparkContext(appName="SimpleTextClassificationPipeline") -sqlCtx = SQLContext(sc) +sqlContext = SQLContext(sc) # Prepare training documents, which are labeled. LabeledDocument = Row("id", "text", "label") diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4441d6a000a02..663f656883721 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1642,7 +1642,7 @@ moved into the udf object in `SQLContext`.
{% highlight java %} -sqlCtx.udf.register("strLen", (s: String) => s.length()) +sqlContext.udf.register("strLen", (s: String) => s.length()) {% endhighlight %}
@@ -1650,7 +1650,7 @@ sqlCtx.udf.register("strLen", (s: String) => s.length())
{% highlight java %} -sqlCtx.udf().register("strLen", (String s) -> { s.length(); }); +sqlContext.udf().register("strLen", (String s) -> { s.length(); }); {% endhighlight %}
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index dee794840a3e1..8159ffbe2d269 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -55,7 +55,7 @@ public void setAge(int age) { public static void main(String[] args) throws Exception { SparkConf sparkConf = new SparkConf().setAppName("JavaSparkSQL"); JavaSparkContext ctx = new JavaSparkContext(sparkConf); - SQLContext sqlCtx = new SQLContext(ctx); + SQLContext sqlContext = new SQLContext(ctx); System.out.println("=== Data source: RDD ==="); // Load a text file and convert each line to a Java Bean. @@ -74,11 +74,11 @@ public Person call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. - DataFrame schemaPeople = sqlCtx.createDataFrame(people, Person.class); + DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. - DataFrame teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -99,12 +99,12 @@ public String call(Row row) { // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. - DataFrame parquetFile = sqlCtx.parquetFile("people.parquet"); + DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); DataFrame teenagers2 = - sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); + sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.toJavaRDD().map(new Function() { @Override public String call(Row row) { @@ -120,7 +120,7 @@ public String call(Row row) { // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; // Create a DataFrame from the file(s) pointed by path - DataFrame peopleFromJsonFile = sqlCtx.jsonFile(path); + DataFrame peopleFromJsonFile = sqlContext.jsonFile(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -133,8 +133,8 @@ public String call(Row row) { // Register this DataFrame as a table. peopleFromJsonFile.registerTempTable("people"); - // SQL statements can be run by using the sql methods provided by sqlCtx. - DataFrame teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + // SQL statements can be run by using the sql methods provided by sqlContext. + DataFrame teenagers3 = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are DataFrame and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -151,7 +151,7 @@ public String call(Row row) { List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - DataFrame peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlContext.jsonRDD(anotherPeopleRDD.rdd()); // Take a look at the schema of this new DataFrame. peopleFromJsonRDD.printSchema(); @@ -164,7 +164,7 @@ public String call(Row row) { peopleFromJsonRDD.registerTempTable("people2"); - DataFrame peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + DataFrame peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { @Override public String call(Row row) { diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py index d281f4fa44282..c73edb7fd6b20 100644 --- a/examples/src/main/python/ml/simple_text_classification_pipeline.py +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -33,7 +33,7 @@ if __name__ == "__main__": sc = SparkContext(appName="SimpleTextClassificationPipeline") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) # Prepare training documents, which are labeled. LabeledDocument = Row("id", "text", "label") diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py index b5a70db2b9a3c..fcbf56cbf0c52 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/mllib/dataset_example.py @@ -44,19 +44,19 @@ def summarize(dataset): print >> sys.stderr, "Usage: dataset_example.py " exit(-1) sc = SparkContext(appName="DatasetExample") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) if len(sys.argv) == 2: input = sys.argv[1] else: input = "data/mllib/sample_libsvm_data.txt" points = MLUtils.loadLibSVMFile(sc, input) - dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache() + dataset0 = sqlContext.inferSchema(points).setName("dataset0").cache() summarize(dataset0) tempdir = tempfile.NamedTemporaryFile(delete=False).name os.unlink(tempdir) print "Save dataset as a Parquet file to %s." % tempdir dataset0.saveAsParquetFile(tempdir) print "Load it back and summarize it again." - dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache() + dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache() summarize(dataset1) shutil.rmtree(tempdir) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 4ff7463498cce..7f42de531f3b4 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -91,9 +91,9 @@ class LogisticRegressionModel(JavaModel): # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.feature tests") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) globs['sc'] = sc - globs['sqlCtx'] = sqlCtx + globs['sqlContext'] = sqlContext (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 433b4fb5d22bf..1cfcd019dfb18 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -117,9 +117,9 @@ def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.feature tests") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) globs['sc'] = sc - globs['sqlCtx'] = sqlCtx + globs['sqlContext'] = sqlContext (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 1a02fece9c5a5..81aa970a32f76 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -53,9 +53,9 @@ try: # Try to access HiveConf, it will raise exception if Hive is not added sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - sqlCtx = HiveContext(sc) + sqlCtx = sqlContext = HiveContext(sc) except py4j.protocol.Py4JError: - sqlCtx = SQLContext(sc) + sqlCtx = sqlContext = SQLContext(sc) print("""Welcome to ____ __ @@ -68,7 +68,7 @@ platform.python_version(), platform.python_build()[0], platform.python_build()[1])) -print("SparkContext available as sc, %s available as sqlCtx." % sqlCtx.__class__.__name__) +print("SparkContext available as sc, %s available as sqlContext." % sqlContext.__class__.__name__) if add_files is not None: print("Warning: ADD_FILES environment variable is deprecated, use --py-files argument instead") diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index c2d81ba804110..93e2d176a5b6f 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -37,12 +37,12 @@ __all__ = ["SQLContext", "HiveContext", "UDFRegistration"] -def _monkey_patch_RDD(sqlCtx): +def _monkey_patch_RDD(sqlContext): def toDF(self, schema=None, sampleRatio=None): """ Converts current :class:`RDD` into a :class:`DataFrame` - This is a shorthand for ``sqlCtx.createDataFrame(rdd, schema, sampleRatio)`` + This is a shorthand for ``sqlContext.createDataFrame(rdd, schema, sampleRatio)`` :param schema: a StructType or list of names of columns :param samplingRatio: the sample ratio of rows used for inferring @@ -51,7 +51,7 @@ def toDF(self, schema=None, sampleRatio=None): >>> rdd.toDF().collect() [Row(name=u'Alice', age=1)] """ - return sqlCtx.createDataFrame(self, schema, sampleRatio) + return sqlContext.createDataFrame(self, schema, sampleRatio) RDD.toDF = toDF @@ -75,13 +75,13 @@ def __init__(self, sparkContext, sqlContext=None): """Creates a new SQLContext. >>> from datetime import datetime - >>> sqlCtx = SQLContext(sc) + >>> sqlContext = SQLContext(sc) >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) >>> df = allTypes.toDF() >>> df.registerTempTable("allTypes") - >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' + >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, @@ -133,18 +133,18 @@ def registerFunction(self, name, f, returnType=StringType()): :param samplingRatio: lambda function :param returnType: a :class:`DataType` object - >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() + >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlContext.sql("SELECT stringLengthString('test')").collect() [Row(c0=u'4')] >>> from pyspark.sql.types import IntegerType - >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(c0=4)] >>> from pyspark.sql.types import IntegerType - >>> sqlCtx.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(c0=4)] """ func = lambda _, it: imap(lambda x: f(*x), it) @@ -229,26 +229,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): :param samplingRatio: the sample ratio of rows used for inferring >>> l = [('Alice', 1)] - >>> sqlCtx.createDataFrame(l).collect() + >>> sqlContext.createDataFrame(l).collect() [Row(_1=u'Alice', _2=1)] - >>> sqlCtx.createDataFrame(l, ['name', 'age']).collect() + >>> sqlContext.createDataFrame(l, ['name', 'age']).collect() [Row(name=u'Alice', age=1)] >>> d = [{'name': 'Alice', 'age': 1}] - >>> sqlCtx.createDataFrame(d).collect() + >>> sqlContext.createDataFrame(d).collect() [Row(age=1, name=u'Alice')] >>> rdd = sc.parallelize(l) - >>> sqlCtx.createDataFrame(rdd).collect() + >>> sqlContext.createDataFrame(rdd).collect() [Row(_1=u'Alice', _2=1)] - >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age']) + >>> df = sqlContext.createDataFrame(rdd, ['name', 'age']) >>> df.collect() [Row(name=u'Alice', age=1)] >>> from pyspark.sql import Row >>> Person = Row('name', 'age') >>> person = rdd.map(lambda r: Person(*r)) - >>> df2 = sqlCtx.createDataFrame(person) + >>> df2 = sqlContext.createDataFrame(person) >>> df2.collect() [Row(name=u'Alice', age=1)] @@ -256,11 +256,11 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): >>> schema = StructType([ ... StructField("name", StringType(), True), ... StructField("age", IntegerType(), True)]) - >>> df3 = sqlCtx.createDataFrame(rdd, schema) + >>> df3 = sqlContext.createDataFrame(rdd, schema) >>> df3.collect() [Row(name=u'Alice', age=1)] - >>> sqlCtx.createDataFrame(df.toPandas()).collect() # doctest: +SKIP + >>> sqlContext.createDataFrame(df.toPandas()).collect() # doctest: +SKIP [Row(name=u'Alice', age=1)] """ if isinstance(data, DataFrame): @@ -316,7 +316,7 @@ def registerDataFrameAsTable(self, df, tableName): Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") + >>> sqlContext.registerDataFrameAsTable(df, "table1") """ if (df.__class__ is DataFrame): self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName) @@ -330,7 +330,7 @@ def parquetFile(self, *paths): >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> df2 = sqlContext.parquetFile(parquetFile) >>> sorted(df.collect()) == sorted(df2.collect()) True """ @@ -352,7 +352,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): >>> shutil.rmtree(jsonFile) >>> with open(jsonFile, 'w') as f: ... f.writelines(jsonStrings) - >>> df1 = sqlCtx.jsonFile(jsonFile) + >>> df1 = sqlContext.jsonFile(jsonFile) >>> df1.printSchema() root |-- field1: long (nullable = true) @@ -365,7 +365,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): ... StructField("field2", StringType()), ... StructField("field3", ... StructType([StructField("field5", ArrayType(IntegerType()))]))]) - >>> df2 = sqlCtx.jsonFile(jsonFile, schema) + >>> df2 = sqlContext.jsonFile(jsonFile, schema) >>> df2.printSchema() root |-- field2: string (nullable = true) @@ -386,11 +386,11 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): If the schema is provided, applies the given schema to this JSON dataset. Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. - >>> df1 = sqlCtx.jsonRDD(json) + >>> df1 = sqlContext.jsonRDD(json) >>> df1.first() Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) - >>> df2 = sqlCtx.jsonRDD(json, df1.schema) + >>> df2 = sqlContext.jsonRDD(json, df1.schema) >>> df2.first() Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) @@ -400,7 +400,7 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): ... StructField("field3", ... StructType([StructField("field5", ArrayType(IntegerType()))])) ... ]) - >>> df3 = sqlCtx.jsonRDD(json, schema) + >>> df3 = sqlContext.jsonRDD(json, schema) >>> df3.first() Row(field2=u'row1', field3=Row(field5=None)) """ @@ -480,8 +480,8 @@ def createExternalTable(self, tableName, path=None, source=None, def sql(self, sqlQuery): """Returns a :class:`DataFrame` representing the result of the given query. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ @@ -490,8 +490,8 @@ def sql(self, sqlQuery): def table(self, tableName): """Returns the specified table as a :class:`DataFrame`. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlCtx.table("table1") + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) True """ @@ -505,8 +505,8 @@ def tables(self, dbName=None): The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` (a column with :class:`BooleanType` indicating if a table is a temporary one or not). - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlCtx.tables() + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.tables() >>> df2.filter("tableName = 'table1'").first() Row(tableName=u'table1', isTemporary=True) """ @@ -520,10 +520,10 @@ def tableNames(self, dbName=None): If ``dbName`` is not specified, the current database will be used. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> "table1" in sqlCtx.tableNames() + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> "table1" in sqlContext.tableNames() True - >>> "table1" in sqlCtx.tableNames("db") + >>> "table1" in sqlContext.tableNames("db") True """ if dbName is None: @@ -578,11 +578,11 @@ def _get_hive_ctx(self): class UDFRegistration(object): """Wrapper for user-defined function registration.""" - def __init__(self, sqlCtx): - self.sqlCtx = sqlCtx + def __init__(self, sqlContext): + self.sqlContext = sqlContext def register(self, name, f, returnType=StringType()): - return self.sqlCtx.registerFunction(name, f, returnType) + return self.sqlContext.registerFunction(name, f, returnType) register.__doc__ = SQLContext.registerFunction.__doc__ @@ -595,13 +595,12 @@ def _test(): globs = pyspark.sql.context.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = sqlCtx = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['rdd'] = rdd = sc.parallelize( [Row(field1=1, field2="row1"), Row(field1=2, field2="row2"), Row(field1=3, field2="row3")] ) - _monkey_patch_RDD(sqlCtx) globs['df'] = rdd.toDF() jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c30326ebd133e..ef91a9c4f522d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -110,7 +110,7 @@ def saveAsParquetFile(self, path): >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> df2 = sqlContext.parquetFile(parquetFile) >>> sorted(df2.collect()) == sorted(df.collect()) True """ @@ -123,7 +123,7 @@ def registerTempTable(self, name): that was used to create this :class:`DataFrame`. >>> df.registerTempTable("people") - >>> df2 = sqlCtx.sql("select * from people") + >>> df2 = sqlContext.sql("select * from people") >>> sorted(df.collect()) == sorted(df2.collect()) True """ @@ -1180,7 +1180,7 @@ def _test(): globs = pyspark.sql.dataframe.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 146ba6f3e0d98..daeb6916b58bc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -161,7 +161,7 @@ def _test(): globs = pyspark.sql.functions.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.functions, globs=globs, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 45eb8b945dcb0..7e0124b13671b 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -434,7 +434,7 @@ def _parse_datatype_json_string(json_string): >>> def check_datatype(datatype): ... pickled = pickle.loads(pickle.dumps(datatype)) ... assert datatype == pickled - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) + ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... assert datatype == python_datatype >>> for cls in _all_primitive_types.values(): @@ -1237,7 +1237,7 @@ def _test(): globs = pyspark.sql.types.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = sqlCtx = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['ExamplePoint'] = ExamplePoint globs['ExamplePointUDT'] = ExamplePointUDT (failure_count, test_count) = doctest.testmod(