diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index 03fedd01016b9..f723cd6b9dfab 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -507,7 +507,6 @@ v_N
The example below demonstrates how to load a simple vectors file, extract a set of vectors, and then transform those vectors using a transforming vector value.
-
{% highlight scala %}
@@ -531,3 +530,57 @@ val transformedData2 = parsedData.map(x => transformer.transform(x))
+## PCA
+
+PCA is a feature transformer that projects vectors to a low-dimensional space of principal components.
+For details, see [dimensionality reduction](mllib-dimensionality-reduction.html).
+
+### Example
+
+The following code demonstrates how to compute principal components on a set of vectors
+and use them to project the vectors into a low-dimensional space, while keeping the associated labels
+for training a [Linear Regression](mllib-linear-methods.html) model on both the original and the projected features.
+
+{% highlight scala %}
+import org.apache.spark.mllib.regression.LinearRegressionWithSGD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.feature.PCA
+
+val data = sc.textFile("data/mllib/ridge-data/lpsa.data").map { line =>
+ val parts = line.split(',')
+ LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
+}.cache()
+
+val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
+val training = splits(0).cache()
+val test = splits(1)
+
+val pca = new PCA(training.first().features.size / 2).fit(data.map(_.features))
+val training_pca = training.map(p => p.copy(features = pca.transform(p.features)))
+val test_pca = test.map(p => p.copy(features = pca.transform(p.features)))
+
+val numIterations = 100
+val model = LinearRegressionWithSGD.train(training, numIterations)
+val model_pca = LinearRegressionWithSGD.train(training_pca, numIterations)
+
+val valuesAndPreds = test.map { point =>
+ val score = model.predict(point.features)
+ (score, point.label)
+}
+
+val valuesAndPreds_pca = test_pca.map { point =>
+ val score = model_pca.predict(point.features)
+ (score, point.label)
+}
+
+val MSE = valuesAndPreds.map { case (v, p) => math.pow(v - p, 2) }.mean()
+val MSE_pca = valuesAndPreds_pca.map { case (v, p) => math.pow(v - p, 2) }.mean()
+
+println("Mean Squared Error = " + MSE)
+println("PCA Mean Squared Error = " + MSE_pca)
+{% endhighlight %}
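+
+The fitted `PCAModel` also exposes the principal components themselves via `pc` (a `DenseMatrix`
+with one component per column) and can project a whole `RDD[Vector]` in one call. The following is
+a minimal sketch, assuming the `RDD`-level `transform` overload inherited from `VectorTransformer`:
+
+{% highlight scala %}
+// Inspect the principal components; each column is one component.
+println(pca.pc)
+
+// Project all training feature vectors at once instead of mapping over LabeledPoints.
+val projectedFeatures = pca.transform(training.map(_.features))
+{% endhighlight %}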
+
+
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
new file mode 100644
index 0000000000000..4e01e402b4283
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.feature
+
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.linalg.distributed.RowMatrix
+import org.apache.spark.rdd.RDD
+
+/**
+ * A feature transformer that projects vectors to a low-dimensional space using PCA.
+ *
+ * @param k number of principal components
+ */
+class PCA(val k: Int) {
+ require(k >= 1, s"PCA requires a number of principal components k >= 1 but was given $k")
+
+ /**
+ * Computes a [[PCAModel]] that contains the principal components of the input vectors.
+ *
+ * @param sources source vectors
+ */
+ def fit(sources: RDD[Vector]): PCAModel = {
+ val numFeatures = sources.first().size
+ require(k <= numFeatures,
+ s"source vector size $numFeatures must be at least k=$k")
+
+ val mat = new RowMatrix(sources)
+ val pc = mat.computePrincipalComponents(k) match {
+ case dm: DenseMatrix =>
+ dm
+ case sm: SparseMatrix =>
+ /* Convert a sparse matrix to dense.
+ *
+ * RowMatrix.computePrincipalComponents always returns a dense matrix.
+ * The following code is a safeguard.
+ */
+ sm.toDense
+ case m =>
+ throw new IllegalArgumentException("Unsupported matrix format. Expected " +
+ s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}")
+ }
+ new PCAModel(k, pc)
+ }
+
+ /** Java-friendly version of [[fit()]] */
+ def fit(sources: JavaRDD[Vector]): PCAModel = fit(sources.rdd)
+}
+
+/**
+ * Model fitted by [[PCA]] that can project vectors to a low-dimensional space using PCA.
+ *
+ * @param k number of principal components.
+ * @param pc a principal components Matrix. Each column is one principal component.
+ */
+class PCAModel private[mllib] (val k: Int, val pc: DenseMatrix) extends VectorTransformer {
+ /**
+ * Transform a vector using the computed principal components.
+ *
+ * @param vector vector to be transformed.
+ * Vector must be the same length as the source vectors given to [[PCA.fit()]].
+ * @return transformed vector. Vector will be of length k.
+ */
+ override def transform(vector: Vector): Vector = {
+ vector match {
+ case dv: DenseVector =>
+ pc.transpose.multiply(dv)
+ case SparseVector(size, indices, values) =>
+ /* SparseVector -> single row SparseMatrix */
+ val sm = Matrices.sparse(size, 1, Array(0, indices.length), indices, values).transpose
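+ // A (1 x size) sparse row times the (size x k) principal components yields a 1 x k dense projection.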
+ val projection = sm.multiply(pc)
+ Vectors.dense(projection.values)
+ case _ =>
+ throw new IllegalArgumentException("Unsupported vector format. Expected " +
+ s"SparseVector or DenseVector. Instead got: ${vector.getClass}")
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
new file mode 100644
index 0000000000000..758af588f1c69
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.distributed.RowMatrix
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class PCASuite extends FunSuite with MLlibTestSparkContext {
+
+ private val data = Array(
+ Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
+ Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
+ Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
+ )
+
+ private lazy val dataRDD = sc.parallelize(data, 2)
+
+ test("Correct computing use a PCA wrapper") {
+ val k = dataRDD.count().toInt
+ val pca = new PCA(k).fit(dataRDD)
+
+ val mat = new RowMatrix(dataRDD)
+ val pc = mat.computePrincipalComponents(k)
+
+ val pca_transform = pca.transform(dataRDD).collect()
+ val mat_multiply = mat.multiply(pc).rows.collect()
+
+ assert(pca_transform.toSet === mat_multiply.toSet)
+ }
+}
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index ed3171b6976d3..3be0979b92013 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -88,12 +88,12 @@ def get$Name(self):
print("\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n")
print("from pyspark.ml.param import Param, Params\n\n")
shared = [
- ("maxIter", "max number of iterations", None),
- ("regParam", "regularization constant", None),
+ ("maxIter", "max number of iterations (>= 0)", None),
+ ("regParam", "regularization parameter (>= 0)", None),
("featuresCol", "features column name", "'features'"),
("labelCol", "label column name", "'label'"),
("predictionCol", "prediction column name", "'prediction'"),
- ("rawPredictionCol", "raw prediction column name", "'rawPrediction'"),
+ ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", "'rawPrediction'"),
("inputCol", "input column name", None),
("inputCols", "input column names", None),
("outputCol", "output column name", None),
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index d0bcadee22347..4b22322b895b4 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -22,16 +22,16 @@
class HasMaxIter(Params):
"""
- Mixin for param maxIter: max number of iterations.
+ Mixin for param maxIter: max number of iterations (>= 0).
"""
# a placeholder to make it appear in the generated doc
- maxIter = Param(Params._dummy(), "maxIter", "max number of iterations")
+ maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0)")
def __init__(self):
super(HasMaxIter, self).__init__()
- #: param for max number of iterations
- self.maxIter = Param(self, "maxIter", "max number of iterations")
+ #: param for max number of iterations (>= 0)
+ self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0)")
if None is not None:
self._setDefault(maxIter=None)
@@ -51,16 +51,16 @@ def getMaxIter(self):
class HasRegParam(Params):
"""
- Mixin for param regParam: regularization constant.
+ Mixin for param regParam: regularization parameter (>= 0).
"""
# a placeholder to make it appear in the generated doc
- regParam = Param(Params._dummy(), "regParam", "regularization constant")
+ regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0)")
def __init__(self):
super(HasRegParam, self).__init__()
- #: param for regularization constant
- self.regParam = Param(self, "regParam", "regularization constant")
+ #: param for regularization parameter (>= 0)
+ self.regParam = Param(self, "regParam", "regularization parameter (>= 0)")
if None is not None:
self._setDefault(regParam=None)
@@ -167,16 +167,16 @@ def getPredictionCol(self):
class HasRawPredictionCol(Params):
"""
- Mixin for param rawPredictionCol: raw prediction column name.
+ Mixin for param rawPredictionCol: raw prediction (a.k.a. confidence) column name.
"""
# a placeholder to make it appear in the generated doc
- rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction column name")
+ rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name")
def __init__(self):
super(HasRawPredictionCol, self).__init__()
- #: param for raw prediction column name
- self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction column name")
+ #: param for raw prediction (a.k.a. confidence) column name
+ self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name")
if 'rawPrediction' is not None:
self._setDefault(rawPredictionCol='rawPrediction')
@@ -403,14 +403,12 @@ class HasStepSize(Params):
"""
# a placeholder to make it appear in the generated doc
- stepSize = Param(Params._dummy(), "stepSize",
- "Step size to be used for each iteration of optimization.")
+ stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization.")
def __init__(self):
super(HasStepSize, self).__init__()
#: param for Step size to be used for each iteration of optimization.
- self.stepSize = Param(self, "stepSize",
- "Step size to be used for each iteration of optimization.")
+ self.stepSize = Param(self, "stepSize", "Step size to be used for each iteration of optimization.")
if None is not None:
self._setDefault(stepSize=None)
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index c1b2077c985cf..fdbae06405f6a 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -179,7 +179,7 @@ def transform(self, dataset, params={}):
return dataset
-class Evaluator(object):
+class Evaluator(Params):
"""
Base class for evaluators that compute metrics from predictions.
"""
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 3a42bcf723894..ba6478dcd58a9 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -34,7 +34,7 @@
from pyspark.sql import DataFrame
from pyspark.ml.param import Param
from pyspark.ml.param.shared import HasMaxIter, HasInputCol
-from pyspark.ml.pipeline import Transformer, Estimator, Pipeline
+from pyspark.ml.pipeline import Estimator, Model, Pipeline, Transformer
class MockDataset(DataFrame):
@@ -77,7 +77,7 @@ def fit(self, dataset, params={}):
return model
-class MockModel(MockTransformer, Transformer):
+class MockModel(MockTransformer, Model):
def __init__(self):
super(MockModel, self).__init__()
@@ -128,7 +128,7 @@ def test_param(self):
testParams = TestParams()
maxIter = testParams.maxIter
self.assertEqual(maxIter.name, "maxIter")
- self.assertEqual(maxIter.doc, "max number of iterations")
+ self.assertEqual(maxIter.doc, "max number of iterations (>= 0)")
self.assertTrue(maxIter.parent is testParams)
def test_params(self):
@@ -156,7 +156,7 @@ def test_params(self):
self.assertEquals(
testParams.explainParams(),
"\n".join(["inputCol: input column name (undefined)",
- "maxIter: max number of iterations (default: 10, current: 100)"]))
+ "maxIter: max number of iterations (>= 0) (default: 10, current: 100)"]))
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 28e3727f2c064..86f4dc7368be0 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -236,6 +236,7 @@ class CrossValidatorModel(Model):
"""
def __init__(self, bestModel):
+ super(CrossValidatorModel, self).__init__()
#: best model from cross validation
self.bestModel = bestModel
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index 1a5083dbe0f61..a03ade3881f59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -109,7 +109,7 @@ private[sql] object JDBCRDD extends Logging {
val fields = new Array[StructField](ncols)
var i = 0
while (i < ncols) {
- val columnName = rsmd.getColumnName(i + 1)
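+ // Use the column label so that aliases from renamed fields (e.g. SELECT NAME AS NAME1) are preserved (SPARK-7345).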
+ val columnName = rsmd.getColumnLabel(i + 1)
val dataType = rsmd.getColumnType(i + 1)
val typeName = rsmd.getColumnTypeName(i + 1)
val fieldSize = rsmd.getPrecision(i + 1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 021affafe36a6..2abfe7f167f77 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -204,6 +204,22 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
assert(ids(2) === 3)
}
+ test("Register JDBC query with renamed fields") {
+ // Regression test for bug SPARK-7345
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE renamed
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (url '$url', dbtable '(select NAME as NAME1, NAME as NAME2 from TEST.PEOPLE)',
+ |user 'testUser', password 'testPass')
+ """.stripMargin.replaceAll("\n", " "))
+
+ val df = sql("SELECT * FROM renamed")
+ assert(df.schema.fields.size == 2)
+ assert(df.schema.fields(0).name == "NAME1")
+ assert(df.schema.fields(1).name == "NAME2")
+ }
+
test("Basic API") {
assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect().size === 3)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index 20a23b3bd6aa9..54f2f3cdec298 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.util.Utils
class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
- import caseInsensisitiveContext._
+ import caseInsensitiveContext._
var path: File = null
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
index ca25751b9583d..6664e8d64c13a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
@@ -64,7 +64,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
}
class DDLTestSuite extends DataSourceTest {
- import caseInsensisitiveContext._
+ import caseInsensitiveContext._
before {
sql(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
index 9d3090c19b4e8..24ed665c67d2e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfter
abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
// We want to test some edge cases.
- implicit val caseInsensisitiveContext = new SQLContext(TestSQLContext.sparkContext)
+ implicit val caseInsensitiveContext = new SQLContext(TestSQLContext.sparkContext)
- caseInsensisitiveContext.setConf(SQLConf.CASE_SENSITIVE, "false")
+ caseInsensitiveContext.setConf(SQLConf.CASE_SENSITIVE, "false")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index cb5e5147ff189..cce747e7dbf64 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -97,7 +97,7 @@ object FiltersPushed {
class FilteredScanSuite extends DataSourceTest {
- import caseInsensisitiveContext._
+ import caseInsensitiveContext._
before {
sql(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 50629ea4dc066..d1d427e1790bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.util.Utils
class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
- import caseInsensisitiveContext._
+ import caseInsensitiveContext._
var path: File = null
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index 6a1ddf2f8e98b..c2bc52e2120c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -52,7 +52,7 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo
}
class PrunedScanSuite extends DataSourceTest {
- import caseInsensisitiveContext._
+ import caseInsensitiveContext._
before {
sql(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
index cb287ba85c1f8..6567d1acd7644 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.util.Utils
class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
- import caseInsensisitiveContext._
+ import caseInsensitiveContext._
var originalDefaultSource: String = null
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 3b47b8adf313b..77af04a491742 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -88,7 +88,7 @@ case class AllDataTypesScan(
}
class TableScanSuite extends DataSourceTest {
- import caseInsensisitiveContext._
+ import caseInsensitiveContext._
var tableWithSchemaExpected = (1 to 10).map { i =>
Row(
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
index 0be5a92c2546c..3458b04bfba0f 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
@@ -147,7 +147,7 @@ object HiveThriftServer2 extends Logging {
override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
server.stop()
}
-
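+ // Number of sessions currently open; unlike sessionList, this does not retain closed sessions.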
+ var onlineSessionNum: Int = 0
val sessionList = new mutable.LinkedHashMap[String, SessionInfo]
val executionList = new mutable.LinkedHashMap[String, ExecutionInfo]
val retainedStatements =
@@ -170,11 +170,13 @@ object HiveThriftServer2 extends Logging {
def onSessionCreated(ip: String, sessionId: String, userName: String = "UNKNOWN"): Unit = {
val info = new SessionInfo(sessionId, System.currentTimeMillis, ip, userName)
sessionList.put(sessionId, info)
+ onlineSessionNum += 1
trimSessionIfNecessary()
}
def onSessionClosed(sessionId: String): Unit = {
sessionList(sessionId).finishTimestamp = System.currentTimeMillis
+ onlineSessionNum -= 1
}
def onStatementStart(
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala
index 71b16b6bebffb..6a2be4a58e5cb 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala
@@ -29,7 +29,7 @@ import org.apache.spark.ui.UIUtils._
import org.apache.spark.ui._
-/** Page for Spark Web UI that shows statistics of a streaming job */
+/** Page for Spark Web UI that shows statistics of a thrift server */
private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("") with Logging {
private val listener = parent.listener
@@ -42,7 +42,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage(""
generateBasicStats() ++
++
- {listener.sessionList.size} session(s) are online,
+ {listener.onlineSessionNum} session(s) are online,
running {listener.totalRunning} SQL statement(s)
++
generateSessionStatsTable() ++
@@ -50,12 +50,12 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage(""
UIUtils.headerSparkPage("ThriftServer", content, parent, Some(5000))
}
- /** Generate basic stats of the streaming program */
+ /** Generate basic stats of the thrift server program */
private def generateBasicStats(): Seq[Node] = {
val timeSinceStart = System.currentTimeMillis() - startTime.getTime
-
- Started at: {startTime.toString}
+ Started at: {formatDate(startTime)}
-
Time since start: {formatDurationVerbose(timeSinceStart)}
@@ -148,7 +148,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage(""
{session.userName} |
{session.ip} |
- {session.sessionId} | ,
+ {session.sessionId} |
{formatDate(session.startTimestamp)} |
{if(session.finishTimestamp > 0) formatDate(session.finishTimestamp)} |
{formatDurationOption(Some(session.totalTime))} |
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala
index 42c49678d24f0..92cfd7d40338c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala
@@ -63,6 +63,11 @@ case class Time(private val millis: Long) {
new Time((this.millis / t) * t)
}
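+ /**
+ * Rounds this time down to the nearest multiple of the given duration,
+ * counted from the given zero time instead of from the epoch.
+ */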
+ def floor(that: Duration, zeroTime: Time): Time = {
+ val t = that.milliseconds
+ new Time(((this.millis - zeroTime.milliseconds) / t) * t + zeroTime.milliseconds)
+ }
+
def isMultipleOf(that: Duration): Boolean =
(this.millis % that.milliseconds == 0)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index f1f8a70655996..7092a3d3f0b86 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -763,16 +763,22 @@ abstract class DStream[T: ClassTag] (
if (!isInitialized) {
throw new SparkException(this + " has not been initialized")
}
- if (!(fromTime - zeroTime).isMultipleOf(slideDuration)) {
- logWarning("fromTime (" + fromTime + ") is not a multiple of slideDuration ("
- + slideDuration + ")")
+
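+ // Round toTime down to the nearest slideDuration boundary (measured from zeroTime) when it is not already aligned.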
+ val alignedToTime = if ((toTime - zeroTime).isMultipleOf(slideDuration)) {
+ toTime
+ } else {
+ logWarning("toTime (" + toTime + ") is not a multiple of slideDuration ("
+ + slideDuration + ")")
+ toTime.floor(slideDuration, zeroTime)
}
- if (!(toTime - zeroTime).isMultipleOf(slideDuration)) {
- logWarning("toTime (" + fromTime + ") is not a multiple of slideDuration ("
- + slideDuration + ")")
+
+ val alignedFromTime = if ((fromTime - zeroTime).isMultipleOf(slideDuration)) {
+ fromTime
+ } else {
+ logWarning("fromTime (" + fromTime + ") is not a multiple of slideDuration ("
+ + slideDuration + ")")
+ fromTime.floor(slideDuration, zeroTime)
}
- val alignedToTime = toTime.floor(slideDuration)
- val alignedFromTime = fromTime.floor(slideDuration)
logInfo("Slicing from " + fromTime + " to " + toTime +
" (aligned to " + alignedFromTime + " and " + alignedToTime + ")")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala
index 5579ac364346c..e6a01656f479d 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala
@@ -69,6 +69,9 @@ class TimeSuite extends TestSuiteBase {
assert(new Time(1200).floor(new Duration(200)) == new Time(1200))
assert(new Time(199).floor(new Duration(200)) == new Time(0))
assert(new Time(1).floor(new Duration(1)) == new Time(1))
+ assert(new Time(1350).floor(new Duration(200), new Time(50)) == new Time(1250))
+ assert(new Time(1350).floor(new Duration(200), new Time(150)) == new Time(1350))
+ assert(new Time(1350).floor(new Duration(200), new Time(200)) == new Time(1200))
}
test("isMultipleOf") {
diff --git a/tox.ini b/tox.ini
index b568029a204cc..76e3f42cde62d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -15,4 +15,4 @@
[pep8]
max-line-length=100
-exclude=cloudpickle.py,heapq3.py
+exclude=cloudpickle.py,heapq3.py,shared.py