[SPARK-6428][SQL] Added explicit type for all public methods in sql/core
Also implemented equals/hashCode where they were missing.

This is done to enable automatic type checking of public method signatures; the pattern is sketched below, after the file summary.

Author: Reynold Xin <rxin@databricks.com>

Closes #5104 from rxin/sql-hashcode-explicittype and squashes the following commits:

ffce6f3 [Reynold Xin] Code review feedback.
8b36733 [Reynold Xin] [SPARK-6428][SQL] Added explicit type for all public methods.
rxin committed Mar 20, 2015
1 parent 257cde7 commit a95043b
Showing 53 changed files with 438 additions and 330 deletions.
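As context for the diff, here is a minimal sketch of the two conventions this
commit applies across sql/core: explicit return types on public members, and
equals/hashCode implemented as a pair. The Point class is a made-up
illustration, not code from this commit.

class Point(val x: Int, val y: Int) {
  // Inferred return type: an edit to the body can silently change the
  // public signature, which an automated check cannot pin down.
  // def magnitudeSquared = x * x + y * y

  // Explicit return type: the public signature is stated outright, so a
  // style or binary-compatibility check can verify it mechanically.
  def magnitudeSquared: Int = x * x + y * y

  // equals and hashCode implemented together, mirroring the GenericRow
  // change below: equal values must produce equal hash codes.
  override def equals(o: Any): Boolean = o match {
    case other: Point => x == other.x && y == other.y
    case _ => false
  }

  override def hashCode(): Int = 31 * x + y
}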
@@ -17,7 +17,6 @@

package org.apache.spark.sql.catalyst.expressions

- import org.apache.spark.sql.catalyst.analysis.Star

protected class AttributeEquals(val a: Attribute) {
override def hashCode() = a match {
@@ -115,7 +114,7 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
// sorts of things in its closure.
override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq

- override def toString = "{" + baseSet.map(_.a).mkString(", ") + "}"
+ override def toString: String = "{" + baseSet.map(_.a).mkString(", ") + "}"

override def isEmpty: Boolean = baseSet.isEmpty
}
@@ -146,6 +146,27 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
result
}

+ override def equals(o: Any): Boolean = o match {
+   case other: Row =>
+     if (values.length != other.length) {
+       return false
+     }
+
+     var i = 0
+     while (i < values.length) {
+       if (isNullAt(i) != other.isNullAt(i)) {
+         return false
+       }
+       if (apply(i) != other.apply(i)) {
+         return false
+       }
+       i += 1
+     }
+     true
+
+   case _ => false
+ }
+
def copy() = this
}

@@ -246,7 +246,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
}

- override def equals(other: Any) = other match {
+ override def equals(other: Any): Boolean = other match {
case d: Decimal =>
compare(d) == 0
case _ =>
@@ -307,7 +307,7 @@ protected[sql] object NativeType {


protected[sql] trait PrimitiveType extends DataType {
- override def isPrimitive = true
+ override def isPrimitive: Boolean = true
}


@@ -442,7 +442,7 @@ class TimestampType private() extends NativeType {
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }

private[sql] val ordering = new Ordering[JvmType] {
- def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
+ def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y)
}

/**
@@ -542,7 +542,7 @@ class LongType private() extends IntegralType {
*/
override def defaultSize: Int = 8

override def simpleString = "bigint"
override def simpleString: String = "bigint"

private[spark] override def asNullable: LongType = this
}
@@ -572,7 +572,7 @@ class IntegerType private() extends IntegralType {
*/
override def defaultSize: Int = 4

override def simpleString = "int"
override def simpleString: String = "int"

private[spark] override def asNullable: IntegerType = this
}
@@ -602,7 +602,7 @@ class ShortType private() extends IntegralType {
*/
override def defaultSize: Int = 2

override def simpleString = "smallint"
override def simpleString: String = "smallint"

private[spark] override def asNullable: ShortType = this
}
@@ -632,7 +632,7 @@ class ByteType private() extends IntegralType {
*/
override def defaultSize: Int = 1

override def simpleString = "tinyint"
override def simpleString: String = "tinyint"

private[spark] override def asNullable: ByteType = this
}
@@ -696,7 +696,7 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
*/
override def defaultSize: Int = 4096

- override def simpleString = precisionInfo match {
+ override def simpleString: String = precisionInfo match {
case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
case None => "decimal(10,0)"
}
@@ -836,7 +836,7 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
*/
override def defaultSize: Int = 100 * elementType.defaultSize

override def simpleString = s"array<${elementType.simpleString}>"
override def simpleString: String = s"array<${elementType.simpleString}>"

private[spark] override def asNullable: ArrayType =
ArrayType(elementType.asNullable, containsNull = true)
@@ -1065,7 +1065,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
*/
override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum

- override def simpleString = {
+ override def simpleString: String = {
val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.simpleString}")
s"struct<${fieldTypes.mkString(",")}>"
}
@@ -1142,7 +1142,7 @@ case class MapType(
*/
override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize)

override def simpleString = s"map<${keyType.simpleString},${valueType.simpleString}>"
override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>"

private[spark] override def asNullable: MapType =
MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -59,7 +59,7 @@ class Column(protected[sql] val expr: Expression) {

override def toString: String = expr.prettyString

- override def equals(that: Any) = that match {
+ override def equals(that: Any): Boolean = that match {
case that: Column => that.expr.equals(this.expr)
case _ => false
}
6 changes: 3 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -33,7 +33,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
- import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser}
+ import org.apache.spark.sql.catalyst.{expressions, ScalaReflection, SqlParser}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
@@ -722,7 +722,7 @@ class DataFrame private[sql](
: DataFrame = {
val dataType = ScalaReflection.schemaFor[B].dataType
val attributes = AttributeReference(outputColumn, dataType)() :: Nil
- def rowFunction(row: Row) = {
+ def rowFunction(row: Row): TraversableOnce[Row] = {
f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType)))
}
val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)
@@ -1155,7 +1155,7 @@ class DataFrame private[sql](
val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)

new Iterator[String] {
- override def hasNext = iter.hasNext
+ override def hasNext: Boolean = iter.hasNext
override def next(): String = {
JsonRDD.rowToJSON(rowSchema, gen)(iter.next())
gen.flush()
8 changes: 4 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -144,7 +144,7 @@ class SQLContext(@transient val sparkContext: SparkContext)

@transient
protected[sql] val tlSession = new ThreadLocal[SQLSession]() {
- override def initialValue = defaultSession
+ override def initialValue: SQLSession = defaultSession
}

@transient
@@ -988,9 +988,9 @@ class SQLContext(@transient val sparkContext: SparkContext)

val sqlContext: SQLContext = self

- def codegenEnabled = self.conf.codegenEnabled
+ def codegenEnabled: Boolean = self.conf.codegenEnabled

- def numPartitions = self.conf.numShufflePartitions
+ def numPartitions: Int = self.conf.numShufflePartitions

def strategies: Seq[Strategy] =
experimental.extraStrategies ++ (
@@ -1109,7 +1109,7 @@ class SQLContext(@transient val sparkContext: SparkContext)

lazy val analyzed: LogicalPlan = analyzer(logical)
lazy val withCachedData: LogicalPlan = {
- assertAnalyzed
+ assertAnalyzed()
cacheManager.useCachedData(analyzed)
}
lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData)
@@ -61,7 +61,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {

val dataType = sqlContext.parseDataType(stringDataType)

- def builder(e: Seq[Expression]) =
+ def builder(e: Seq[Expression]): PythonUDF =
PythonUDF(
name,
command,
@@ -48,9 +48,9 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](

protected def initialize() {}

- def hasNext = buffer.hasRemaining
+ override def hasNext: Boolean = buffer.hasRemaining

- def extractTo(row: MutableRow, ordinal: Int): Unit = {
+ override def extractTo(row: MutableRow, ordinal: Int): Unit = {
extractSingle(row, ordinal)
}

@@ -58,7 +58,7 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
override def initialize(
initialSize: Int,
columnName: String = "",
- useCompression: Boolean = false) = {
+ useCompression: Boolean = false): Unit = {

val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize
this.columnName = columnName
@@ -73,7 +73,7 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
columnType.append(row, ordinal, buffer)
}

- override def build() = {
+ override def build(): ByteBuffer = {
buffer.flip().asInstanceOf[ByteBuffer]
}
}
@@ -76,7 +76,7 @@ private[sql] sealed trait ColumnStats extends Serializable {
private[sql] class NoopColumnStats extends ColumnStats {
override def gatherStats(row: Row, ordinal: Int): Unit = super.gatherStats(row, ordinal)

- def collectedStatistics = Row(null, null, nullCount, count, 0L)
+ override def collectedStatistics: Row = Row(null, null, nullCount, count, 0L)
}

private[sql] class BooleanColumnStats extends ColumnStats {
@@ -93,7 +93,7 @@ private[sql] class BooleanColumnStats extends ColumnStats {
}
}

- def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
}

private[sql] class ByteColumnStats extends ColumnStats {
@@ -110,7 +110,7 @@ private[sql] class ByteColumnStats extends ColumnStats {
}
}

- def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
}

private[sql] class ShortColumnStats extends ColumnStats {
@@ -127,7 +127,7 @@ private[sql] class ShortColumnStats extends ColumnStats {
}
}

- def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
}

private[sql] class LongColumnStats extends ColumnStats {
@@ -144,7 +144,7 @@ private[sql] class LongColumnStats extends ColumnStats {
}
}

- def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
}

private[sql] class DoubleColumnStats extends ColumnStats {
@@ -161,7 +161,7 @@ private[sql] class DoubleColumnStats extends ColumnStats {
}
}

- def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
}

private[sql] class FloatColumnStats extends ColumnStats {
@@ -178,7 +178,7 @@ private[sql] class FloatColumnStats extends ColumnStats {
}
}

- def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
}

private[sql] class FixedDecimalColumnStats extends ColumnStats {
@@ -212,7 +212,7 @@ private[sql] class IntColumnStats extends ColumnStats {
}
}

- def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
}

private[sql] class StringColumnStats extends ColumnStats {
@@ -229,7 +229,7 @@ private[sql] class StringColumnStats extends ColumnStats {
}
}

- def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
}

private[sql] class DateColumnStats extends IntColumnStats
@@ -248,7 +248,7 @@ private[sql] class TimestampColumnStats extends ColumnStats {
}
}

- def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
}

private[sql] class BinaryColumnStats extends ColumnStats {
@@ -259,7 +259,7 @@ private[sql] class BinaryColumnStats extends ColumnStats {
}
}

- def collectedStatistics = Row(null, null, nullCount, count, sizeInBytes)
+ override def collectedStatistics: Row = Row(null, null, nullCount, count, sizeInBytes)
}

private[sql] class GenericColumnStats extends ColumnStats {
@@ -270,5 +270,5 @@ private[sql] class GenericColumnStats extends ColumnStats {
}
}

- def collectedStatistics = Row(null, null, nullCount, count, sizeInBytes)
+ override def collectedStatistics: Row = Row(null, null, nullCount, count, sizeInBytes)
}