diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
index edbdda8a0bcb6..34ee3a48f8e74 100644
--- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
+++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
@@ -45,8 +45,7 @@ class SparkStatusTracker private[spark] (sc: SparkContext) {
*/
def getJobIdsForGroup(jobGroup: String): Array[Int] = {
jobProgressListener.synchronized {
- val jobData = jobProgressListener.jobIdToData.valuesIterator
- jobData.filter(_.jobGroup.orNull == jobGroup).map(_.jobId).toArray
+ jobProgressListener.jobGroupToJobIds.getOrElse(jobGroup, Seq.empty).toArray
}
}
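A minimal usage sketch of the lookup changed above (the group name and local master are illustrative, not part of the patch): jobs submitted while a job group is set are indexed under that group, so getJobIdsForGroup becomes a single map lookup.

    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setAppName("status-tracker-sketch").setMaster("local"))
    sc.setJobGroup("myGroup", "illustrative job group")
    sc.parallelize(1 to 100).count()
    // Returns the IDs of the jobs that ran under "myGroup"
    val jobIds: Array[Int] = sc.statusTracker.getJobIdsForGroup("myGroup")
    sc.stop()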
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index 0997507d016f5..9db6fd1ac4dbe 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -101,6 +101,8 @@ private[deploy] object DeployMessages {
case class RegisterApplication(appDescription: ApplicationDescription)
extends DeployMessage
+ case class UnregisterApplication(appId: String)
+
case class MasterChangeAcknowledged(appId: String)
// Master to AppClient
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
index 3b729725257ef..4f06d7f96c46e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
@@ -157,6 +157,7 @@ private[spark] class AppClient(
case StopAppClient =>
markDead("Application has been stopped.")
+ master ! UnregisterApplication(appId)
sender ! true
context.stop(self)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index 536aedb6f9fe9..bc5b293379f2b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -91,7 +91,7 @@ private[deploy] class ApplicationInfo(
}
}
- private[master] val requestedCores = desc.maxCores.getOrElse(defaultCores)
+ private val requestedCores = desc.maxCores.getOrElse(defaultCores)
private[master] def coresLeft: Int = requestedCores - coresGranted
@@ -111,6 +111,10 @@ private[deploy] class ApplicationInfo(
endTime = System.currentTimeMillis()
}
+ private[master] def isFinished: Boolean = {
+ state != ApplicationState.WAITING && state != ApplicationState.RUNNING
+ }
+
def duration: Long = {
if (endTime != -1) {
endTime - startTime
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 80506621f4d24..9a5d5877da86d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -339,7 +339,11 @@ private[master] class Master(
if (ExecutorState.isFinished(state)) {
// Remove this executor from the worker and app
logInfo(s"Removing executor ${exec.fullId} because it is $state")
- appInfo.removeExecutor(exec)
+ // If an application has already finished, preserve its
+ // state to display its information properly on the UI
+ if (!appInfo.isFinished) {
+ appInfo.removeExecutor(exec)
+ }
exec.worker.removeExecutor(exec)
val normalExit = exitStatus == Some(0)
@@ -428,6 +432,10 @@ private[master] class Master(
if (canCompleteRecovery) { completeRecovery() }
}
+ case UnregisterApplication(applicationId) =>
+ logInfo(s"Received unregister request from application $applicationId")
+ idToApp.get(applicationId).foreach(finishApplication)
+
case DisassociatedEvent(_, address, _) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
logInfo(s"$address got disassociated, removing it.")
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index 46509e39c0f23..45412a35e9a7d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -75,16 +75,12 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
val workers = state.workers.sortBy(_.id)
val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers)
- val activeAppHeaders = Seq("Application ID", "Name", "Cores in Use",
- "Cores Requested", "Memory per Node", "Submitted Time", "User", "State", "Duration")
+ val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time",
+ "User", "State", "Duration")
val activeApps = state.activeApps.sortBy(_.startTime).reverse
- val activeAppsTable = UIUtils.listingTable(activeAppHeaders, activeAppRow, activeApps)
-
- val completedAppHeaders = Seq("Application ID", "Name", "Cores Requested", "Memory per Node",
- "Submitted Time", "User", "State", "Duration")
+ val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps)
val completedApps = state.completedApps.sortBy(_.endTime).reverse
- val completedAppsTable = UIUtils.listingTable(completedAppHeaders, completeAppRow,
- completedApps)
+ val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps)
val driverHeaders = Seq("Submission ID", "Submitted Time", "Worker", "State", "Cores",
"Memory", "Main Class")
@@ -191,7 +187,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
}
- private def appRow(app: ApplicationInfo, active: Boolean): Seq[Node] = {
+ private def appRow(app: ApplicationInfo): Seq[Node] = {
val killLink = if (parent.killEnabled &&
(app.state == ApplicationState.RUNNING || app.state == ApplicationState.WAITING)) {
val killLinkUri = s"app/kill?id=${app.id}&terminate=true"
@@ -201,7 +197,6 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
(kill)
}
-
{app.id}
@@ -210,15 +205,8 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
|
{app.desc.name}
|
- {
- if (active) {
-
- {app.coresGranted}
- |
- }
- }
- {if (app.requestedCores == Int.MaxValue) "*" else app.requestedCores}
+ {app.coresGranted}
|
{Utils.megabytesToString(app.desc.memoryPerSlave)}
@@ -230,14 +218,6 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
|
}
- private def activeAppRow(app: ApplicationInfo): Seq[Node] = {
- appRow(app, active = true)
- }
-
- private def completeAppRow(app: ApplicationInfo): Seq[Node] = {
- appRow(app, active = false)
- }
-
private def driverRow(driver: DriverInfo): Seq[Node] = {
val killLink = if (parent.killEnabled &&
(driver.state == DriverState.RUNNING ||
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 6fa1f2c880f7a..132a9ced77700 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -81,9 +81,11 @@ class TaskInfo(
def status: String = {
if (running) {
- "RUNNING"
- } else if (gettingResult) {
- "GET RESULT"
+ if (gettingResult) {
+ "GET RESULT"
+ } else {
+ "RUNNING"
+ }
} else if (failed) {
"FAILED"
} else if (successful) {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
index 660df00bc32f5..d0178dfde6935 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
@@ -112,6 +112,7 @@ class FileShuffleBlockManager(conf: SparkConf)
private val shuffleState = shuffleStates(shuffleId)
private var fileGroup: ShuffleFileGroup = null
+ val openStartTime = System.nanoTime
val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
fileGroup = getUnusedFileGroup()
Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
@@ -135,6 +136,9 @@ class FileShuffleBlockManager(conf: SparkConf)
blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics)
}
}
+ // Creating the file to write to and creating a disk writer both involve interacting with
+ // the disk, so should be included in the shuffle write time.
+ writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime)
override def releaseWriters(success: Boolean) {
if (consolidateShuffleFiles) {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index fa2e617762f55..55ea0f17b156a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -63,6 +63,9 @@ private[spark] class SortShuffleWriter[K, V, C](
sorter.insertAll(records)
}
+ // Don't bother including the time to open the merged output file in the shuffle write time,
+ // because it just opens a single file, so is typically too fast to measure accurately
+ // (see SPARK-3570).
val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId)
val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 80d66e59132da..1dff09a75d038 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -535,9 +535,14 @@ private[spark] class BlockManager(
/* We'll store the bytes in memory if the block's storage level includes
* "memory serialized", or if it should be cached as objects in memory
* but we only requested its serialized bytes. */
- val copyForMemory = ByteBuffer.allocate(bytes.limit)
- copyForMemory.put(bytes)
- memoryStore.putBytes(blockId, copyForMemory, level)
+ memoryStore.putBytes(blockId, bytes.limit, () => {
+ // https://issues.apache.org/jira/browse/SPARK-6076
+ // If the file size is bigger than the free memory, OOM will happen. So if we cannot
+ // put it into MemoryStore, copyForMemory should not be created. That's why this
+ // action is put into a `() => ByteBuffer` and created lazily.
+ val copyForMemory = ByteBuffer.allocate(bytes.limit)
+ copyForMemory.put(bytes)
+ })
bytes.rewind()
}
if (!asBlockResult) {
@@ -991,15 +996,23 @@ private[spark] class BlockManager(
putIterator(blockId, Iterator(value), level, tellMaster)
}
+ def dropFromMemory(
+ blockId: BlockId,
+ data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = {
+ dropFromMemory(blockId, () => data)
+ }
+
/**
* Drop a block from memory, possibly putting it on disk if applicable. Called when the memory
* store reaches its limit and needs to free up space.
*
+ * If `data` is not put on disk, it won't be created.
+ *
* Return the block status if the given block has been updated, else None.
*/
def dropFromMemory(
blockId: BlockId,
- data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = {
+ data: () => Either[Array[Any], ByteBuffer]): Option[BlockStatus] = {
logInfo(s"Dropping block $blockId from memory")
val info = blockInfo.get(blockId).orNull
@@ -1023,7 +1036,7 @@ private[spark] class BlockManager(
// Drop to disk, if storage level requires
if (level.useDisk && !diskStore.contains(blockId)) {
logInfo(s"Writing block $blockId to disk")
- data match {
+ data() match {
case Left(elements) =>
diskStore.putArray(blockId, elements, level, returnValues = false)
case Right(bytes) =>
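A standalone sketch of the deferred-allocation idea behind the `() => ByteBuffer` thunks introduced above (and in MemoryStore below); the helper name and sizes are illustrative, not Spark APIs.

    import java.nio.ByteBuffer

    // The buffer is passed as a thunk, so it is only allocated once we know it fits;
    // if there is not enough room, the (possibly huge) copy is never created.
    def storeIfRoom(size: Long, freeBytes: Long, makeBuffer: () => ByteBuffer): Option[ByteBuffer] = {
      if (size <= freeBytes) Some(makeBuffer()) else None
    }

    val stored = storeIfRoom(16L, 1024L, () => ByteBuffer.allocate(16))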
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 1be860aea63d0..ed609772e6979 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -98,6 +98,26 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
+ /**
+ * Use `size` to test if there is enough space in MemoryStore. If so, create the ByteBuffer and
+ * put it into MemoryStore. Otherwise, the ByteBuffer won't be created.
+ *
+ * The caller should guarantee that `size` is correct.
+ */
+ def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): PutResult = {
+ // Work on a duplicate - since the original input might be used elsewhere.
+ lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer]
+ val putAttempt = tryToPut(blockId, () => bytes, size, deserialized = false)
+ val data =
+ if (putAttempt.success) {
+ assert(bytes.limit == size)
+ Right(bytes.duplicate())
+ } else {
+ null
+ }
+ PutResult(size, data, putAttempt.droppedBlocks)
+ }
+
override def putArray(
blockId: BlockId,
values: Array[Any],
@@ -312,11 +332,22 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
blockId.asRDDId.map(_.rddId)
}
+ private def tryToPut(
+ blockId: BlockId,
+ value: Any,
+ size: Long,
+ deserialized: Boolean): ResultWithDroppedBlocks = {
+ tryToPut(blockId, () => value, size, deserialized)
+ }
+
/**
* Try to put in a set of values, if we can free up enough space. The value should either be
* an Array if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) size
* must also be passed by the caller.
*
+ * `value` will be lazily created. If it cannot be put into MemoryStore or disk, `value` won't be
+ * created to avoid OOM since it may be a big ByteBuffer.
+ *
* Synchronize on `accountingLock` to ensure that all the put requests and its associated block
* dropping is done by only one thread at a time. Otherwise while one thread is dropping
* blocks to free memory for one block, another thread may use up the freed space for
@@ -326,7 +357,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
*/
private def tryToPut(
blockId: BlockId,
- value: Any,
+ value: () => Any,
size: Long,
deserialized: Boolean): ResultWithDroppedBlocks = {
@@ -345,7 +376,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
droppedBlocks ++= freeSpaceResult.droppedBlocks
if (enoughFreeSpace) {
- val entry = new MemoryEntry(value, size, deserialized)
+ val entry = new MemoryEntry(value(), size, deserialized)
entries.synchronized {
entries.put(blockId, entry)
currentMemory += size
@@ -357,12 +388,12 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
} else {
// Tell the block manager that we couldn't put it in memory so that it can drop it to
// disk if the block allows disk storage.
- val data = if (deserialized) {
- Left(value.asInstanceOf[Array[Any]])
+ lazy val data = if (deserialized) {
+ Left(value().asInstanceOf[Array[Any]])
} else {
- Right(value.asInstanceOf[ByteBuffer].duplicate())
+ Right(value().asInstanceOf[ByteBuffer].duplicate())
}
- val droppedBlockStatus = blockManager.dropFromMemory(blockId, data)
+ val droppedBlockStatus = blockManager.dropFromMemory(blockId, () => data)
droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) }
}
// Release the unroll memory used because we no longer need the underlying Array
diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
index 19ac7a826e306..5fbcd6bb8ad94 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ui
+import java.util.concurrent.Semaphore
+
import scala.util.Random
import org.apache.spark.{SparkConf, SparkContext}
@@ -88,6 +90,8 @@ private[spark] object UIWorkloadGenerator {
("Job with delays", baseData.map(x => Thread.sleep(100)).count)
)
+ val barrier = new Semaphore(-nJobSet * jobs.size + 1)
+
(1 to nJobSet).foreach { _ =>
for ((desc, job) <- jobs) {
new Thread {
@@ -99,12 +103,17 @@ private[spark] object UIWorkloadGenerator {
} catch {
case e: Exception =>
println("Job Failed: " + desc)
+ } finally {
+ barrier.release()
}
}
}.start
Thread.sleep(INTER_JOB_WAIT_MS)
}
}
+
+ // Wait for all job threads to finish before stopping the SparkContext.
+ barrier.acquire()
sc.stop()
}
}
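A minimal sketch of the negative-permit barrier used above: acquire() only returns after every worker thread has released once (the thread count here is illustrative).

    import java.util.concurrent.Semaphore

    val n = 4
    // Starting at -(n - 1) permits means acquire() cannot succeed until all n threads release.
    val barrier = new Semaphore(-n + 1)
    (1 to n).foreach { i =>
      new Thread {
        override def run(): Unit = {
          try println(s"job $i done") finally barrier.release()
        }
      }.start()
    }
    barrier.acquire()  // blocks until all n threads have finished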
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 949e80d30f5eb..625596885faa1 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -44,6 +44,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
// These type aliases are public because they're used in the types of public fields:
type JobId = Int
+ type JobGroupId = String
type StageId = Int
type StageAttemptId = Int
type PoolName = String
@@ -54,6 +55,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
val completedJobs = ListBuffer[JobUIData]()
val failedJobs = ListBuffer[JobUIData]()
val jobIdToData = new HashMap[JobId, JobUIData]
+ val jobGroupToJobIds = new HashMap[JobGroupId, HashSet[JobId]]
// Stages:
val pendingStages = new HashMap[StageId, StageInfo]
@@ -119,7 +121,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
Map(
"jobIdToData" -> jobIdToData.size,
"stageIdToData" -> stageIdToData.size,
- "stageIdToStageInfo" -> stageIdToInfo.size
+ "stageIdToStageInfo" -> stageIdToInfo.size,
+ "jobGroupToJobIds" -> jobGroupToJobIds.values.map(_.size).sum,
+ // Since jobGroupToJobIds is a map of sets, check that we don't leak keys with empty values:
+ "jobGroupToJobIds keySet" -> jobGroupToJobIds.keys.size
)
}
@@ -140,7 +145,19 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
if (jobs.size > retainedJobs) {
val toRemove = math.max(retainedJobs / 10, 1)
jobs.take(toRemove).foreach { job =>
- jobIdToData.remove(job.jobId)
+ // Remove the job's UI data, if it exists
+ jobIdToData.remove(job.jobId).foreach { removedJob =>
+ // A null jobGroupId is used for jobs that are run without a job group
+ val jobGroupId = removedJob.jobGroup.orNull
+ // Remove the job group -> job mapping entry, if it exists
+ jobGroupToJobIds.get(jobGroupId).foreach { jobsInGroup =>
+ jobsInGroup.remove(job.jobId)
+ // If this was the last job in this job group, remove the map entry for the job group
+ if (jobsInGroup.isEmpty) {
+ jobGroupToJobIds.remove(jobGroupId)
+ }
+ }
+ }
}
jobs.trimStart(toRemove)
}
@@ -158,6 +175,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
stageIds = jobStart.stageIds,
jobGroup = jobGroup,
status = JobExecutionStatus.RUNNING)
+ // A null jobGroupId is used for jobs that are run without a job group
+ jobGroupToJobIds.getOrElseUpdate(jobGroup.orNull, new HashSet[JobId]).add(jobStart.jobId)
jobStart.stageInfos.foreach(x => pendingStages(x.stageId) = x)
// Compute (a potential underestimate of) the number of tasks that will be run by this job.
// This may be an underestimate because the job start event references all of the result
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index e03442894c5cc..797c9404bc449 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -269,11 +269,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
+: getFormattedTimeQuantiles(serializationTimes)
val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) =>
- if (info.gettingResultTime > 0) {
- (info.finishTime - info.gettingResultTime).toDouble
- } else {
- 0.0
- }
+ getGettingResultTime(info).toDouble
}
val gettingResultQuantiles =
@@ -464,7 +460,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L)
val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L)
val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
- val gettingResultTime = info.gettingResultTime
+ val gettingResultTime = getGettingResultTime(info)
val maybeAccumulators = info.accumulables
val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"}
@@ -627,6 +623,19 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
| {errorSummary}{details} |
}
+ private def getGettingResultTime(info: TaskInfo): Long = {
+ if (info.gettingResultTime > 0) {
+ if (info.finishTime > 0) {
+ info.finishTime - info.gettingResultTime
+ } else {
+ // The task is still fetching the result.
+ System.currentTimeMillis - info.gettingResultTime
+ }
+ } else {
+ 0L
+ }
+ }
+
private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = {
val totalExecutionTime =
if (info.gettingResult) {
@@ -638,6 +647,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
}
val executorOverhead = (metrics.executorDeserializeTime +
metrics.resultSerializationTime)
- math.max(0, totalExecutionTime - metrics.executorRunTime - executorOverhead)
+ math.max(
+ 0,
+ totalExecutionTime - metrics.executorRunTime - executorOverhead - getGettingResultTime(info))
}
}
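A worked example of the adjusted formula (figures are made up): with 1000 ms total execution time, 700 ms executor run time, 50 ms deserialization, 20 ms result serialization and 100 ms spent fetching the result, the reported scheduler delay becomes 130 ms rather than 230 ms.

    // Illustrative figures only: delay = max(0, total - run - overhead - gettingResult)
    val delay = math.max(0L, 1000L - 700L - (50L + 20L) - 100L)  // = 130 ms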
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index d9a671687aad0..0b5a914e7dbbf 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1876,6 +1876,10 @@ private[spark] object Utils extends Logging {
startService: Int => (T, Int),
conf: SparkConf,
serviceName: String = ""): (T, Int) = {
+
+ require(startPort == 0 || (1024 <= startPort && startPort < 65536),
+ "startPort should be between 1024 and 65535 (inclusive), or 0 for a random free port.")
+
val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'"
val maxRetries = portMaxRetries(conf)
for (offset <- 0 to maxRetries) {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 3262e670c2030..b962c101c91da 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -352,6 +352,7 @@ private[spark] class ExternalSorter[K, V, C](
// Create our file writers if we haven't done so yet
if (partitionWriters == null) {
curWriteMetrics = new ShuffleWriteMetrics()
+ val openStartTime = System.nanoTime
partitionWriters = Array.fill(numPartitions) {
// Because these files may be read during shuffle, their compression must be controlled by
// spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
@@ -359,6 +360,10 @@ private[spark] class ExternalSorter[K, V, C](
val (blockId, file) = diskBlockManager.createTempShuffleBlock()
blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open()
}
+ // Creating the file to write to and creating a disk writer both involve interacting with
+ // the disk, and can take a long time in aggregate when we open many files, so should be
+ // included in the shuffle write time.
+ curWriteMetrics.incShuffleWriteTime(System.nanoTime - openStartTime)
}
// No need to sort stuff, just write each element out
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
index c52591b352340..efc2482c74ddf 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -53,6 +53,15 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size
+ /** Tests whether this map contains a binding for a key. */
+ def contains(k: K): Boolean = {
+ if (k == null) {
+ haveNullValue
+ } else {
+ _keySet.getPos(k) != OpenHashSet.INVALID_POS
+ }
+ }
+
/** Get the value for a given key */
def apply(k: K): V = {
if (k == null) {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index c80057f95e0b2..1501111a06655 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -122,7 +122,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
*/
def addWithoutResize(k: T): Int = {
var pos = hashcode(hasher.hash(k)) & _mask
- var i = 1
+ var delta = 1
while (true) {
if (!_bitset.get(pos)) {
// This is a new key.
@@ -134,14 +134,12 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
// Found an existing key.
return pos
} else {
- val delta = i
+ // quadratic probing with values increasing by 1, 2, 3, ...
pos = (pos + delta) & _mask
- i += 1
+ delta += 1
}
}
- // Never reached here
- assert(INVALID_POS != INVALID_POS)
- INVALID_POS
+ throw new RuntimeException("Should never reach here.")
}
/**
@@ -163,21 +161,19 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
*/
def getPos(k: T): Int = {
var pos = hashcode(hasher.hash(k)) & _mask
- var i = 1
- val maxProbe = _data.size
- while (i < maxProbe) {
+ var delta = 1
+ while (true) {
if (!_bitset.get(pos)) {
return INVALID_POS
} else if (k == _data(pos)) {
return pos
} else {
- val delta = i
+ // quadratic probing with values increasing by 1, 2, 3, ...
pos = (pos + delta) & _mask
- i += 1
+ delta += 1
}
}
- // Never reached here
- INVALID_POS
+ throw new RuntimeException("Should never reach here.")
}
/** Return the value at the specified position. */
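A small sketch of the probe sequence produced by the `delta` counter above: the offsets visited from the starting position are the triangular numbers 0, 1, 3, 6, 10, ... (i.e. k*(k+1)/2, taken modulo the table size).

    // First five positions probed for a slot starting at pos = 0 (mask ignored for clarity).
    val probes = Iterator.iterate((0, 1)) { case (p, d) => (p + d, d + 1) }.map(_._1).take(5).toList
    // probes == List(0, 1, 3, 6, 10)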
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
index 61e22642761f0..b4ec4ea521253 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -48,6 +48,11 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
override def size: Int = _keySet.size
+ /** Tests whether this map contains a binding for a key. */
+ def contains(k: K): Boolean = {
+ _keySet.getPos(k) != OpenHashSet.INVALID_POS
+ }
+
/** Get the value for a given key */
def apply(k: K): V = {
val pos = _keySet.getPos(k)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 3fdbe99b5d02b..ecd1cba5b5abe 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -170,8 +170,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach
assert(master.getLocations("a3").size === 0, "master was told about a3")
// Drop a1 and a2 from memory; this should be reported back to the master
- store.dropFromMemory("a1", null)
- store.dropFromMemory("a2", null)
+ store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer])
+ store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer])
assert(store.getSingle("a1") === None, "a1 not removed from store")
assert(store.getSingle("a2") === None, "a2 not removed from store")
assert(master.getLocations("a1").size === 0, "master did not remove a1")
@@ -413,8 +413,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach
t2.join()
t3.join()
- store.dropFromMemory("a1", null)
- store.dropFromMemory("a2", null)
+ store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer])
+ store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer])
store.waitForAsyncReregister()
}
}
@@ -1223,4 +1223,30 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach
assert(unrollMemoryAfterB6 === unrollMemoryAfterB4)
assert(unrollMemoryAfterB7 === unrollMemoryAfterB4)
}
+
+ test("lazily create a big ByteBuffer to avoid OOM if it cannot be put into MemoryStore") {
+ store = makeBlockManager(12000)
+ val memoryStore = store.memoryStore
+ val blockId = BlockId("rdd_3_10")
+ val result = memoryStore.putBytes(blockId, 13000, () => {
+ fail("A big ByteBuffer that cannot be put into MemoryStore should not be created")
+ })
+ assert(result.size === 13000)
+ assert(result.data === null)
+ assert(result.droppedBlocks === Nil)
+ }
+
+ test("put a small ByteBuffer to MemoryStore") {
+ store = makeBlockManager(12000)
+ val memoryStore = store.memoryStore
+ val blockId = BlockId("rdd_3_10")
+ var bytes: ByteBuffer = null
+ val result = memoryStore.putBytes(blockId, 10000, () => {
+ bytes = ByteBuffer.allocate(10000)
+ bytes
+ })
+ assert(result.size === 10000)
+ assert(result.data === Right(bytes))
+ assert(result.droppedBlocks === Nil)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 730a4b54f5aa1..c0c28cb60e21d 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ui.jobs
+import java.util.Properties
+
import org.scalatest.FunSuite
import org.scalatest.Matchers
@@ -44,11 +46,19 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
SparkListenerStageCompleted(stageInfo)
}
- private def createJobStartEvent(jobId: Int, stageIds: Seq[Int]) = {
+ private def createJobStartEvent(
+ jobId: Int,
+ stageIds: Seq[Int],
+ jobGroup: Option[String] = None): SparkListenerJobStart = {
val stageInfos = stageIds.map { stageId =>
new StageInfo(stageId, 0, stageId.toString, 0, null, "")
}
- SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos)
+ val properties: Option[Properties] = jobGroup.map { groupId =>
+ val props = new Properties()
+ props.setProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
+ props
+ }
+ SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos, properties.orNull)
}
private def createJobEndEvent(jobId: Int, failed: Boolean = false) = {
@@ -110,6 +120,23 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
listener.stageIdToActiveJobIds.size should be (0)
}
+ test("test clearing of jobGroupToJobIds") {
+ val conf = new SparkConf()
+ conf.set("spark.ui.retainedJobs", 5.toString)
+ val listener = new JobProgressListener(conf)
+
+ // Run 51 jobs (ids 0 to 50), each with one stage
+ for (jobId <- 0 to 50) {
+ listener.onJobStart(createJobStartEvent(jobId, Seq(0), jobGroup = Some(jobId.toString)))
+ listener.onStageSubmitted(createStageStartEvent(0))
+ listener.onStageCompleted(createStageEndEvent(0, failed = false))
+ listener.onJobEnd(createJobEndEvent(jobId, false))
+ }
+ assertActiveJobsStateIsEmpty(listener)
+ // This collection won't become empty, but it should be bounded by spark.ui.retainedJobs
+ listener.jobGroupToJobIds.size should be (5)
+ }
+
test("test LRU eviction of jobs") {
val conf = new SparkConf()
conf.set("spark.ui.retainedStages", 5.toString)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
index 6a70877356409..ef890d2ba60f3 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
@@ -176,4 +176,14 @@ class OpenHashMapSuite extends FunSuite with Matchers {
assert(map(i.toString) === i.toString)
}
}
+
+ test("contains") {
+ val map = new OpenHashMap[String, Int](2)
+ map("a") = 1
+ assert(map.contains("a"))
+ assert(!map.contains("b"))
+ assert(!map.contains(null))
+ map(null) = 0
+ assert(map.contains(null))
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
index 8c7df7d73dcd3..caf378fec8b3e 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
@@ -118,4 +118,11 @@ class PrimitiveKeyOpenHashMapSuite extends FunSuite with Matchers {
assert(map(i.toLong) === i.toString)
}
}
+
+ test("contains") {
+ val map = new PrimitiveKeyOpenHashMap[Int, Int](1)
+ map(0) = 0
+ assert(map.contains(0))
+ assert(!map.contains(1))
+ }
}
diff --git a/dev/run-tests b/dev/run-tests
index d6935a61c6d29..561d7fc9e7b1f 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -178,6 +178,15 @@ CURRENT_BLOCK=$BLOCK_BUILD
fi
}
+echo ""
+echo "========================================================================="
+echo "Detecting binary incompatibilities with MiMa"
+echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_MIMA
+
+./dev/mima
+
echo ""
echo "========================================================================="
echo "Running Spark unit tests"
@@ -227,12 +236,3 @@ echo "========================================================================="
CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS
./python/run-tests
-
-echo ""
-echo "========================================================================="
-echo "Detecting binary incompatibilities with MiMa"
-echo "========================================================================="
-
-CURRENT_BLOCK=$BLOCK_MIMA
-
-./dev/mima
diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh
index 1348e0609dda4..8ab6db6925d6e 100644
--- a/dev/run-tests-codes.sh
+++ b/dev/run-tests-codes.sh
@@ -22,6 +22,6 @@ readonly BLOCK_RAT=11
readonly BLOCK_SCALA_STYLE=12
readonly BLOCK_PYTHON_STYLE=13
readonly BLOCK_BUILD=14
-readonly BLOCK_SPARK_UNIT_TESTS=15
-readonly BLOCK_PYSPARK_UNIT_TESTS=16
-readonly BLOCK_MIMA=17
+readonly BLOCK_MIMA=15
+readonly BLOCK_SPARK_UNIT_TESTS=16
+readonly BLOCK_PYSPARK_UNIT_TESTS=17
diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins
index 5f4000e83925c..3a937b637e003 100755
--- a/dev/run-tests-jenkins
+++ b/dev/run-tests-jenkins
@@ -199,12 +199,12 @@ done
failing_test="Python style tests"
elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then
failing_test="to build"
+ elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then
+ failing_test="MiMa tests"
elif [ "$test_result" -eq "$BLOCK_SPARK_UNIT_TESTS" ]; then
failing_test="Spark unit tests"
elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then
failing_test="PySpark unit tests"
- elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then
- failing_test="MiMa tests"
else
failing_test="some tests"
fi
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index da6aef7f14c4c..c08c76d226713 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -408,31 +408,31 @@ import org.apache.spark.sql.SQLContext;
// Labeled and unlabeled instance types.
// Spark SQL can infer schema from Java Beans.
public class Document implements Serializable {
- private Long id;
+ private long id;
private String text;
- public Document(Long id, String text) {
+ public Document(long id, String text) {
this.id = id;
this.text = text;
}
- public Long getId() { return this.id; }
- public void setId(Long id) { this.id = id; }
+ public long getId() { return this.id; }
+ public void setId(long id) { this.id = id; }
public String getText() { return this.text; }
public void setText(String text) { this.text = text; }
}
public class LabeledDocument extends Document implements Serializable {
- private Double label;
+ private double label;
- public LabeledDocument(Long id, String text, Double label) {
+ public LabeledDocument(long id, String text, double label) {
super(id, text);
this.label = label;
}
- public Double getLabel() { return this.label; }
- public void setLabel(Double label) { this.label = label; }
+ public double getLabel() { return this.label; }
+ public void setLabel(double label) { this.label = label; }
}
// Set up contexts.
@@ -565,6 +565,11 @@ import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{Row, SQLContext}
+// Labeled and unlabeled instance types.
+// Spark SQL can infer schema from case classes.
+case class LabeledDocument(id: Long, text: String, label: Double)
+case class Document(id: Long, text: String)
+
val conf = new SparkConf().setAppName("CrossValidatorExample")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
@@ -655,6 +660,36 @@ import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
+// Labeled and unlabeled instance types.
+// Spark SQL can infer schema from Java Beans.
+public class Document implements Serializable {
+ private long id;
+ private String text;
+
+ public Document(long id, String text) {
+ this.id = id;
+ this.text = text;
+ }
+
+ public long getId() { return this.id; }
+ public void setId(long id) { this.id = id; }
+
+ public String getText() { return this.text; }
+ public void setText(String text) { this.text = text; }
+}
+
+public class LabeledDocument extends Document implements Serializable {
+ private double label;
+
+ public LabeledDocument(long id, String text, double label) {
+ super(id, text);
+ this.label = label;
+ }
+
+ public double getLabel() { return this.label; }
+ public void setLabel(double label) { this.label = label; }
+}
+
SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
SQLContext jsql = new SQLContext(jsc);
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index 0b6db4fcb7b1f..f5aa15b7d9b79 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -173,6 +173,7 @@ to the algorithm. We then output the parameters of the mixture model.
{% highlight scala %}
import org.apache.spark.mllib.clustering.GaussianMixture
+import org.apache.spark.mllib.clustering.GaussianMixtureModel
import org.apache.spark.mllib.linalg.Vectors
// Load and parse the data
@@ -182,6 +183,10 @@ val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble)))
// Cluster the data into two classes using GaussianMixture
val gmm = new GaussianMixture().setK(2).run(parsedData)
+// Save and load model
+gmm.save(sc, "myGMMModel")
+val sameModel = GaussianMixtureModel.load(sc, "myGMMModel")
+
// output parameters of max-likelihood model
for (i <- 0 until gmm.k) {
println("weight=%f\nmu=%s\nsigma=\n%s\n" format
@@ -231,6 +236,9 @@ public class GaussianMixtureExample {
// Cluster the data into two classes using GaussianMixture
GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd());
+ // Save and load GaussianMixtureModel
+ gmm.save(sc.sc(), "myGMMModel");
+ GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc.sc(), "myGMMModel");
// Output the parameters of the mixture model
for(int j=0; j<gmm.k(); j++) {
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 68b1aeb8ebd01..d9f3eb2b74b18 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -274,6 +274,6 @@ If you need a reference to the proper location to put log files in the YARN so t
# Important notes
- Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured.
-- The local directories used by Spark executors will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored.
+- In `yarn-cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `yarn-client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `yarn-client` mode, only the Spark executors do.
- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN.
- The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `yarn-cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files.
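A minimal Scala sketch of the driver-side setting discussed in the note above (the path is illustrative): `spark.local.dir` only takes effect for the driver in `yarn-client` mode; executors, and everything in `yarn-cluster` mode, use `yarn.nodemanager.local-dirs`.

    import org.apache.spark.SparkConf

    // Honoured by the driver in yarn-client mode; ignored wherever YARN manages the process.
    val conf = new SparkConf().setAppName("local-dir-example").set("spark.local.dir", "/tmp/driver-scratch")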
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 6a333fdb562a7..c99a0b03442c4 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -624,7 +624,8 @@ tuples or lists in the RDD created in the step 1.
For example:
{% highlight python %}
# Import SQLContext and data types
-from pyspark.sql import *
+from pyspark.sql import SQLContext
+from pyspark.sql.types import *
# sc is an existing SparkContext.
sqlContext = SQLContext(sc)
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
index 322de7bf2fed8..51d273af8da84 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
@@ -28,6 +28,7 @@ import scala.language.postfixOps
import com.google.common.base.Charsets
import org.apache.avro.ipc.NettyTransceiver
import org.apache.avro.ipc.specific.SpecificRequestor
+import org.apache.commons.lang3.RandomUtils
import org.apache.flume.source.avro
import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol}
import org.jboss.netty.channel.ChannelPipeline
@@ -40,7 +41,6 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}
-import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted}
import org.apache.spark.util.Utils
class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging {
@@ -76,7 +76,8 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L
/** Find a free port */
private def findFreePort(): Int = {
- Utils.startServiceOnPort(23456, (trialPort: Int) => {
+ val candidatePort = RandomUtils.nextInt(1024, 65536)
+ Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
val socket = new ServerSocket(trialPort)
socket.close()
(null, trialPort)
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index 0f3298af6234a..24d78ecb3a97d 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -25,6 +25,7 @@ import scala.concurrent.duration._
import scala.language.postfixOps
import org.apache.activemq.broker.{TransportConnector, BrokerService}
+import org.apache.commons.lang3.RandomUtils
import org.eclipse.paho.client.mqttv3._
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
@@ -113,7 +114,8 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
}
private def findFreePort(): Int = {
- Utils.startServiceOnPort(23456, (trialPort: Int) => {
+ val candidatePort = RandomUtils.nextInt(1024, 65536)
+ Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
val socket = new ServerSocket(trialPort)
socket.close()
(null, trialPort)
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
index dc90e9e987234..2da5f7278729e 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ -147,7 +147,6 @@ void addOptionString(List cmd, String options) {
*/
List<String> buildClassPath(String appClassPath) throws IOException {
String sparkHome = getSparkHome();
- String scala = getScalaVersion();
List<String> cp = new ArrayList<String>();
addToClassPath(cp, getenv("SPARK_CLASSPATH"));
@@ -158,6 +157,7 @@ List buildClassPath(String appClassPath) throws IOException {
boolean prependClasses = !isEmpty(getenv("SPARK_PREPEND_CLASSES"));
boolean isTesting = "1".equals(getenv("SPARK_TESTING"));
if (prependClasses || isTesting) {
+ String scala = getScalaVersion();
List<String> projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx",
"streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver",
"yarn", "launcher");
@@ -182,7 +182,7 @@ List buildClassPath(String appClassPath) throws IOException {
addToClassPath(cp, String.format("%s/core/target/jars/*", sparkHome));
}
- String assembly = findAssembly(scala);
+ String assembly = findAssembly();
addToClassPath(cp, assembly);
// When Hive support is needed, Datanucleus jars must be included on the classpath. Datanucleus
@@ -330,7 +330,7 @@ String getenv(String key) {
return firstNonEmpty(childEnv.get(key), System.getenv(key));
}
- private String findAssembly(String scalaVersion) {
+ private String findAssembly() {
String sparkHome = getSparkHome();
File libdir;
if (new File(sparkHome, "RELEASE").isFile()) {
@@ -338,7 +338,7 @@ private String findAssembly(String scalaVersion) {
checkState(libdir.isDirectory(), "Library directory '%s' does not exist.",
libdir.getAbsolutePath());
} else {
- libdir = new File(sparkHome, String.format("assembly/target/scala-%s", scalaVersion));
+ libdir = new File(sparkHome, String.format("assembly/target/scala-%s", getScalaVersion()));
}
final Pattern re = Pattern.compile("spark-assembly.*hadoop.*\\.jar");
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 0b1f90daa7d8e..68401e36950bd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{ParamMap, IntParam, BooleanParam, Param}
import org.apache.spark.sql.types.{DataType, StringType, ArrayType}
/**
@@ -39,3 +39,67 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
override protected def outputDataType: DataType = new ArrayType(StringType, false)
}
+
+/**
+ * :: AlphaComponent ::
+ * A regex based tokenizer that extracts tokens either by repeatedly matching the regex (default)
+ * or using it to split the text (set matching to false). Optional parameters also allow folding
+ * the text to lowercase prior to tokenizing and filtering tokens by a minimum length.
+ * It returns an array of strings that can be empty.
+ * The default parameters are regex = "\\p{L}+|[^\\p{L}\\s]+", matching = true,
+ * lowercase = false, minTokenLength = 1
+ */
+@AlphaComponent
+class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
+
+ /**
+ * param for minimum token length, default is one to avoid returning empty strings
+ * @group param
+ */
+ val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length", Some(1))
+
+ /** @group setParam */
+ def setMinTokenLength(value: Int): this.type = set(minTokenLength, value)
+
+ /** @group getParam */
+ def getMinTokenLength: Int = get(minTokenLength)
+
+ /**
+ * param sets regex as splitting on gaps (true) or matching tokens (false)
+ * @group param
+ */
+ val gaps: BooleanParam = new BooleanParam(
+ this, "gaps", "Set regex to match gaps or tokens", Some(false))
+
+ /** @group setParam */
+ def setGaps(value: Boolean): this.type = set(gaps, value)
+
+ /** @group getParam */
+ def getGaps: Boolean = get(gaps)
+
+ /**
+ * param sets regex pattern used by tokenizer
+ * @group param
+ */
+ val pattern: Param[String] = new Param(
+ this, "pattern", "regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+"))
+
+ /** @group setParam */
+ def setPattern(value: String): this.type = set(pattern, value)
+
+ /** @group getParam */
+ def getPattern: String = get(pattern)
+
+ override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str =>
+ val re = paramMap(pattern).r
+ val tokens = if (paramMap(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
+ val minLength = paramMap(minTokenLength)
+ tokens.filter(_.length >= minLength)
+ }
+
+ override protected def validateInputType(inputType: DataType): Unit = {
+ require(inputType == StringType, s"Input type must be string type but got $inputType.")
+ }
+
+ override protected def outputDataType: DataType = new ArrayType(StringType, false)
+}
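A minimal usage sketch of the new transformer (column names mirror the Java test added below; the pattern splits on whitespace and keeps tokens of at least three characters):

    val regexTokenizer = new RegexTokenizer()
      .setInputCol("rawText")
      .setOutputCol("tokens")
      .setPattern("\\s")
      .setGaps(true)
      .setMinTokenLength(3)
    // regexTokenizer.transform(dataset) adds a "tokens" column to a DataFrame with a "rawText" column.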
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 15ca2547d56a8..e39156734794c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -111,9 +111,11 @@ private[python] class PythonMLLibAPI extends Serializable {
initialWeights: Vector,
regParam: Double,
regType: String,
- intercept: Boolean): JList[Object] = {
+ intercept: Boolean,
+ validateData: Boolean): JList[Object] = {
val lrAlg = new LinearRegressionWithSGD()
lrAlg.setIntercept(intercept)
+ .setValidateData(validateData)
lrAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
@@ -135,8 +137,12 @@ private[python] class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeights: Vector): JList[Object] = {
+ initialWeights: Vector,
+ intercept: Boolean,
+ validateData: Boolean): JList[Object] = {
val lassoAlg = new LassoWithSGD()
+ lassoAlg.setIntercept(intercept)
+ .setValidateData(validateData)
lassoAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
@@ -157,8 +163,12 @@ private[python] class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeights: Vector): JList[Object] = {
+ initialWeights: Vector,
+ intercept: Boolean,
+ validateData: Boolean): JList[Object] = {
val ridgeAlg = new RidgeRegressionWithSGD()
+ ridgeAlg.setIntercept(intercept)
+ .setValidateData(validateData)
ridgeAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index af6f83c74bb40..ec65a3da689de 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -19,11 +19,17 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseVector => BreezeVector}
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
-import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{SQLContext, Row}
/**
* :: Experimental ::
@@ -41,10 +47,16 @@ import org.apache.spark.rdd.RDD
@Experimental
class GaussianMixtureModel(
val weights: Array[Double],
- val gaussians: Array[MultivariateGaussian]) extends Serializable {
+ val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable {
require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
-
+
+ override protected def formatVersion = "1.0"
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ GaussianMixtureModel.SaveLoadV1_0.save(sc, path, weights, gaussians)
+ }
+
/** Number of gaussians in mixture */
def k: Int = weights.length
@@ -83,5 +95,79 @@ class GaussianMixtureModel(
p(i) /= pSum
}
p
- }
+ }
+}
+
+@Experimental
+object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
+
+ private object SaveLoadV1_0 {
+
+ case class Data(weight: Double, mu: Vector, sigma: Matrix)
+
+ val formatVersionV1_0 = "1.0"
+
+ val classNameV1_0 = "org.apache.spark.mllib.clustering.GaussianMixtureModel"
+
+ def save(
+ sc: SparkContext,
+ path: String,
+ weights: Array[Double],
+ gaussians: Array[MultivariateGaussian]): Unit = {
+
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Create JSON metadata.
+ val metadata = compact(render
+ (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ ("k" -> weights.length)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ // Create Parquet data.
+ val dataArray = Array.tabulate(weights.length) { i =>
+ Data(weights(i), gaussians(i).mu, gaussians(i).sigma)
+ }
+ sc.parallelize(dataArray, 1).toDF().saveAsParquetFile(Loader.dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): GaussianMixtureModel = {
+ val dataPath = Loader.dataPath(path)
+ val sqlContext = new SQLContext(sc)
+ val dataFrame = sqlContext.parquetFile(dataPath)
+ val dataArray = dataFrame.select("weight", "mu", "sigma").collect()
+
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ Loader.checkSchema[Data](dataFrame.schema)
+
+ val (weights, gaussians) = dataArray.map {
+ case Row(weight: Double, mu: Vector, sigma: Matrix) =>
+ (weight, new MultivariateGaussian(mu, sigma))
+ }.unzip
+
+ return new GaussianMixtureModel(weights.toArray, gaussians.toArray)
+ }
+ }
+
+ override def load(sc: SparkContext, path: String) : GaussianMixtureModel = {
+ val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+ implicit val formats = DefaultFormats
+ val k = (metadata \ "k").extract[Int]
+ val classNameV1_0 = SaveLoadV1_0.classNameV1_0
+ (loadedClassName, version) match {
+ case (classNameV1_0, "1.0") => {
+ val model = SaveLoadV1_0.load(sc, path)
+ require(model.weights.length == k,
+ s"GaussianMixtureModel requires weights of length $k " +
+ s"got weights of length ${model.weights.length}")
+ require(model.gaussians.length == k,
+ s"GaussianMixtureModel requires gaussians of length $k" +
+ s"got gaussians of length ${model.gaussians.length}")
+ model
+ }
+ case _ => throw new Exception(
+ s"GaussianMixtureModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 45b9ebb4cc0d6..9fd60ff7a0c79 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -211,6 +211,10 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
*/
def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
+ if (numFeatures < 0) {
+ numFeatures = input.map(_.features.size).first()
+ }
+
if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
new file mode 100644
index 0000000000000..3806f650025b2
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaTokenizerSuite {
+ private transient JavaSparkContext jsc;
+ private transient SQLContext jsql;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaTokenizerSuite");
+ jsql = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void regexTokenizer() {
+ RegexTokenizer myRegExTokenizer = new RegexTokenizer()
+ .setInputCol("rawText")
+ .setOutputCol("tokens")
+ .setPattern("\\s")
+ .setGaps(true)
+ .setMinTokenLength(3);
+
+    JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Lists.newArrayList(
+ new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
+ new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"})
+ ));
+ DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);
+
+ Row[] pairs = myRegExTokenizer.transform(dataset)
+ .select("tokens", "wantedTokens")
+ .collect();
+
+ for (Row r : pairs) {
+ Assert.assertEquals(r.get(0), r.get(1));
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
new file mode 100644
index 0000000000000..bf862b912d326
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.beans.BeanInfo
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+@BeanInfo
+case class TokenizerTestData(rawText: String, wantedTokens: Seq[String]) {
+ /** Constructor used in [[org.apache.spark.ml.feature.JavaTokenizerSuite]] */
+ def this(rawText: String, wantedTokens: Array[String]) = this(rawText, wantedTokens.toSeq)
+}
+
+class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
+ import org.apache.spark.ml.feature.RegexTokenizerSuite._
+
+ @transient var sqlContext: SQLContext = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ }
+
+ test("RegexTokenizer") {
+ val tokenizer = new RegexTokenizer()
+ .setInputCol("rawText")
+ .setOutputCol("tokens")
+
+ val dataset0 = sqlContext.createDataFrame(Seq(
+ TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")),
+ TokenizerTestData("Te,st. punct", Seq("Te", ",", "st", ".", "punct"))
+ ))
+ testRegexTokenizer(tokenizer, dataset0)
+
+ val dataset1 = sqlContext.createDataFrame(Seq(
+ TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization")),
+ TokenizerTestData("Te,st. punct", Seq("punct"))
+ ))
+
+ tokenizer.setMinTokenLength(3)
+ testRegexTokenizer(tokenizer, dataset1)
+
+ tokenizer
+ .setPattern("\\s")
+ .setGaps(true)
+ .setMinTokenLength(0)
+ val dataset2 = sqlContext.createDataFrame(Seq(
+ TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization.")),
+      TokenizerTestData("Te,st.  punct", Seq("Te,st.", "", "punct"))
+ ))
+ testRegexTokenizer(tokenizer, dataset2)
+ }
+}
+
+object RegexTokenizerSuite extends FunSuite {
+
+ def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = {
+ t.transform(dataset)
+ .select("tokens", "wantedTokens")
+ .collect()
+ .foreach {
+ case Row(tokens, wantedTokens) =>
+ assert(tokens === wantedTokens)
+ }
+ }
+}
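
As a companion to the suite above, a short sketch of the two RegexTokenizer modes being exercised (spark.ml 1.x setters; the regexes mirror the tests):

    import org.apache.spark.ml.feature.RegexTokenizer

    // Matching mode: the pattern describes the tokens themselves.
    val matching = new RegexTokenizer()
      .setInputCol("rawText").setOutputCol("tokens")
      .setGaps(false).setPattern("\\p{L}+|[^\\p{L}\\s]+")

    // Gap mode: the pattern describes the separators; minTokenLength drops short or empty tokens.
    val splitting = new RegexTokenizer()
      .setInputCol("rawText").setOutputCol("tokens")
      .setGaps(true).setPattern("\\s").setMinTokenLength(3)
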
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index aaa81da9e273c..a26c52852c4d7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -425,6 +425,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
val model = lr.run(testRDD)
+ val numFeatures = testRDD.map(_.features.size).first()
+ val initialWeights = Vectors.dense(new Array[Double]((numFeatures + 1) * 2))
+ val model2 = lr.run(testRDD, initialWeights)
+
+ LogisticRegressionSuite.checkModelsEqual(model, model2)
+
/**
* The following is the instruction to reproduce the model using R's glmnet package.
*
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
index 1b46a4012d731..f356ffa3e3a26 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Matrices}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
test("single cluster") {
@@ -48,13 +49,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
}
test("two clusters") {
- val data = sc.parallelize(Array(
- Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
- Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
- Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
- Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
- Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
- ))
+ val data = sc.parallelize(GaussianTestData.data)
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
@@ -105,14 +100,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
}
test("two clusters with sparse data") {
- val data = sc.parallelize(Array(
- Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
- Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
- Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
- Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
- Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
- ))
-
+ val data = sc.parallelize(GaussianTestData.data)
val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray))
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
@@ -138,4 +126,38 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
}
+
+ test("model save / load") {
+ val data = sc.parallelize(GaussianTestData.data)
+
+ val gmm = new GaussianMixture().setK(2).setSeed(0).run(data)
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ try {
+ gmm.save(sc, path)
+
+ // TODO: GaussianMixtureModel should implement equals/hashcode directly.
+ val sameModel = GaussianMixtureModel.load(sc, path)
+ assert(sameModel.k === gmm.k)
+ (0 until sameModel.k).foreach { i =>
+ assert(sameModel.gaussians(i).mu === gmm.gaussians(i).mu)
+ assert(sameModel.gaussians(i).sigma === gmm.gaussians(i).sigma)
+ }
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
+ object GaussianTestData {
+
+ val data = Array(
+ Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
+ Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
+ Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
+ Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
+ Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
+ )
+
+ }
}
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 414a0ada80787..209f1ee473b5b 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -140,6 +140,13 @@ class LinearRegressionModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=100, step=1.0,
+ ... miniBatchFraction=1.0, initialWeights=array([1.0]), regParam=0.1, regType="l2",
+ ... intercept=True, validateData=True)
+ >>> abs(lrm.predict(array([0.0])) - 0) < 0.5
+ True
+ >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+ True
"""
def save(self, sc, path):
java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel(
@@ -173,7 +180,8 @@ class LinearRegressionWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
- initialWeights=None, regParam=0.0, regType=None, intercept=False):
+ initialWeights=None, regParam=0.0, regType=None, intercept=False,
+ validateData=True):
"""
Train a linear regression model on the given data.
@@ -195,15 +203,18 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
(default: None)
- @param intercept: Boolean parameter which indicates the use
+ :param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not). (default: False)
+ :param validateData: Boolean parameter which indicates if the
+ algorithm should validate data before training.
+ (default: True)
"""
def train(rdd, i):
return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations),
float(step), float(miniBatchFraction), i, float(regParam),
- regType, bool(intercept))
+ regType, bool(intercept), bool(validateData))
return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights)
@@ -253,6 +264,13 @@ class LassoModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=100, step=1.0,
+ ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True,
+ ... validateData=True)
+ >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5
+ True
+ >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+ True
"""
def save(self, sc, path):
java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel(
@@ -273,11 +291,13 @@ class LassoWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, regParam=0.01,
- miniBatchFraction=1.0, initialWeights=None):
+ miniBatchFraction=1.0, initialWeights=None, intercept=False,
+ validateData=True):
"""Train a Lasso regression model on the given data."""
def train(rdd, i):
return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step),
- float(regParam), float(miniBatchFraction), i)
+ float(regParam), float(miniBatchFraction), i, bool(intercept),
+ bool(validateData))
return _regression_train_wrapper(train, LassoModel, data, initialWeights)
@@ -327,6 +347,13 @@ class RidgeRegressionModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=100, step=1.0,
+ ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True,
+ ... validateData=True)
+ >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5
+ True
+ >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+ True
"""
def save(self, sc, path):
java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel(
@@ -347,11 +374,13 @@ class RidgeRegressionWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, regParam=0.01,
- miniBatchFraction=1.0, initialWeights=None):
+ miniBatchFraction=1.0, initialWeights=None, intercept=False,
+ validateData=True):
"""Train a ridge regression model on the given data."""
def train(rdd, i):
return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step),
- float(regParam), float(miniBatchFraction), i)
+ float(regParam), float(miniBatchFraction), i, bool(intercept),
+ bool(validateData))
return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights)
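
The new Python keyword arguments map onto existing setters on the JVM side; a hedged sketch of the equivalent direct Scala call (Spark 1.x MLlib API, illustrative data, `sc` assumed in scope):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.{LabeledPoint, LassoWithSGD}

    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0)),
      LabeledPoint(1.0, Vectors.dense(1.0))))

    val lasso = new LassoWithSGD()
    lasso.setIntercept(true)       // mirrors intercept=True
      .setValidateData(true)       // mirrors validateData=True
    lasso.optimizer.setNumIterations(100).setStepSize(1.0).setRegParam(0.01)
    val model = lasso.run(data)
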
diff --git a/repl/pom.xml b/repl/pom.xml
index edfa1c7f2c29c..03053b4c3b287 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -84,6 +84,11 @@
      <artifactId>scalacheck_${scala.binary.version}</artifactId>
      <scope>test</scope>
    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-all</artifactId>
+      <scope>test</scope>
+    </dependency>
diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
index 9805609120005..004941d5f50ae 100644
--- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -17,9 +17,10 @@
package org.apache.spark.repl
-import java.io.{ByteArrayOutputStream, InputStream, FileNotFoundException}
-import java.net.{URI, URL, URLEncoder}
-import java.util.concurrent.{Executors, ExecutorService}
+import java.io.{IOException, ByteArrayOutputStream, InputStream}
+import java.net.{HttpURLConnection, URI, URL, URLEncoder}
+
+import scala.util.control.NonFatal
import org.apache.hadoop.fs.{FileSystem, Path}
@@ -43,6 +44,9 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
val parentLoader = new ParentClassLoader(parent)
+ // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes
+ private[repl] var httpUrlConnectionTimeoutMillis: Int = -1
+
// Hadoop FileSystem object for our URI, if it isn't using HTTP
var fileSystem: FileSystem = {
if (Set("http", "https", "ftp").contains(uri.getScheme)) {
@@ -71,30 +75,66 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
}
}
+ private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = {
+ val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
+ val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
+ val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
+ newuri.toURL
+ } else {
+ new URL(classUri + "/" + urlEncode(pathInDirectory))
+ }
+ val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(),
+ SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection]
+ // Set the connection timeouts (for testing purposes)
+ if (httpUrlConnectionTimeoutMillis != -1) {
+ connection.setConnectTimeout(httpUrlConnectionTimeoutMillis)
+ connection.setReadTimeout(httpUrlConnectionTimeoutMillis)
+ }
+ connection.connect()
+ try {
+ if (connection.getResponseCode != 200) {
+ // Close the error stream so that the connection is eligible for re-use
+ try {
+ connection.getErrorStream.close()
+ } catch {
+ case ioe: IOException =>
+ logError("Exception while closing error stream", ioe)
+ }
+ throw new ClassNotFoundException(s"Class file not found at URL $url")
+ } else {
+ connection.getInputStream
+ }
+ } catch {
+ case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] =>
+ connection.disconnect()
+ throw e
+ }
+ }
+
+ private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = {
+ val path = new Path(directory, pathInDirectory)
+ if (fileSystem.exists(path)) {
+ fileSystem.open(path)
+ } else {
+ throw new ClassNotFoundException(s"Class file not found at path $path")
+ }
+ }
+
def findClassLocally(name: String): Option[Class[_]] = {
+ val pathInDirectory = name.replace('.', '/') + ".class"
+ var inputStream: InputStream = null
try {
- val pathInDirectory = name.replace('.', '/') + ".class"
- val inputStream = {
+ inputStream = {
if (fileSystem != null) {
- fileSystem.open(new Path(directory, pathInDirectory))
+ getClassFileInputStreamFromFileSystem(pathInDirectory)
} else {
- val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
- val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
- val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
- newuri.toURL
- } else {
- new URL(classUri + "/" + urlEncode(pathInDirectory))
- }
-
- Utils.setupSecureURLConnection(url.openConnection(), SparkEnv.get.securityManager)
- .getInputStream
+ getClassFileInputStreamFromHttpServer(pathInDirectory)
}
}
val bytes = readAndTransformClass(name, inputStream)
- inputStream.close()
Some(defineClass(name, bytes, 0, bytes.length))
} catch {
- case e: FileNotFoundException =>
+ case e: ClassNotFoundException =>
// We did not find the class
logDebug(s"Did not load class $name from REPL class server at $uri", e)
None
@@ -102,6 +142,15 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
// Something bad happened while checking if the class exists
logError(s"Failed to check existence of class $name on REPL class server at $uri", e)
None
+ } finally {
+ if (inputStream != null) {
+ try {
+ inputStream.close()
+ } catch {
+ case e: Exception =>
+ logError("Exception while closing inputStream", e)
+ }
+ }
}
}
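
The HTTP branch above follows the usual HttpURLConnection hygiene: close the error stream on a non-200 response so the connection can return to the keep-alive pool, and disconnect on unexpected failures. A stripped-down sketch of that idiom (plain JDK API, independent of Spark):

    import java.io.InputStream
    import java.net.{HttpURLConnection, URL}

    import scala.util.control.NonFatal

    def fetchOrFail(url: URL, timeoutMillis: Int): InputStream = {
      val conn = url.openConnection().asInstanceOf[HttpURLConnection]
      conn.setConnectTimeout(timeoutMillis)
      conn.setReadTimeout(timeoutMillis)
      conn.connect()
      try {
        if (conn.getResponseCode != HttpURLConnection.HTTP_OK) {
          // Closing the error stream keeps the underlying connection eligible for re-use.
          Option(conn.getErrorStream).foreach(_.close())
          throw new ClassNotFoundException(s"Class file not found at URL $url")
        } else {
          conn.getInputStream
        }
      } catch {
        case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] =>
          conn.disconnect()
          throw e
      }
    }
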
diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
index 6a79e76a34db8..c709cde740748 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
@@ -20,13 +20,25 @@ package org.apache.spark.repl
import java.io.File
import java.net.{URL, URLClassLoader}
+import scala.concurrent.duration._
+import scala.language.implicitConversions
+import scala.language.postfixOps
+
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
+import org.scalatest.concurrent.Interruptor
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.mock.MockitoSugar
+import org.mockito.Mockito._
-import org.apache.spark.{SparkConf, TestUtils}
+import org.apache.spark._
import org.apache.spark.util.Utils
-class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
+class ExecutorClassLoaderSuite
+ extends FunSuite
+ with BeforeAndAfterAll
+ with MockitoSugar
+ with Logging {
val childClassNames = List("ReplFakeClass1", "ReplFakeClass2")
val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3")
@@ -34,6 +46,7 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
var tempDir2: File = _
var url1: String = _
var urls2: Array[URL] = _
+ var classServer: HttpServer = _
override def beforeAll() {
super.beforeAll()
@@ -47,8 +60,12 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
override def afterAll() {
super.afterAll()
+ if (classServer != null) {
+ classServer.stop()
+ }
Utils.deleteRecursively(tempDir1)
Utils.deleteRecursively(tempDir2)
+ SparkEnv.set(null)
}
test("child first") {
@@ -83,4 +100,53 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
}
}
+ test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") {
+ // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class
+    // from the driver's class server would leak an HTTP connection, causing the class server's
+ // thread / connection pool to be exhausted.
+ val conf = new SparkConf()
+ val securityManager = new SecurityManager(conf)
+ classServer = new HttpServer(conf, tempDir1, securityManager)
+ classServer.start()
+ // ExecutorClassLoader uses SparkEnv's SecurityManager, so we need to mock this
+ val mockEnv = mock[SparkEnv]
+ when(mockEnv.securityManager).thenReturn(securityManager)
+ SparkEnv.set(mockEnv)
+ // Create an ExecutorClassLoader that's configured to load classes from the HTTP server
+ val parentLoader = new URLClassLoader(Array.empty, null)
+ val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false)
+ classLoader.httpUrlConnectionTimeoutMillis = 500
+ // Check that this class loader can actually load classes that exist
+ val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance()
+ val fakeClassVersion = fakeClass.toString
+ assert(fakeClassVersion === "1")
+ // Try to perform a full GC now, since GC during the test might mask resource leaks
+ System.gc()
+ // When the original bug occurs, the test thread becomes blocked in a classloading call
+ // and does not respond to interrupts. Therefore, use a custom ScalaTest interruptor to
+ // shut down the HTTP server when the test times out
+ val interruptor: Interruptor = new Interruptor {
+ override def apply(thread: Thread): Unit = {
+ classServer.stop()
+ classServer = null
+ thread.interrupt()
+ }
+ }
+ def tryAndFailToLoadABunchOfClasses(): Unit = {
+ // The number of trials here should be much larger than Jetty's thread / connection limit
+ // in order to expose thread or connection leaks
+ for (i <- 1 to 1000) {
+ if (Thread.currentThread().isInterrupted) {
+ throw new InterruptedException()
+ }
+ // Incorporate the iteration number into the class name in order to avoid any response
+ // caching that might be added in the future
+ intercept[ClassNotFoundException] {
+ classLoader.loadClass(s"ReplFakeClassDoesNotExist$i").newInstance()
+ }
+ }
+ }
+ failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor)
+ }
+
}
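
For readers unfamiliar with the ScalaTest machinery used above, a minimal sketch of failAfter with a custom Interruptor, mirroring the suite's imports (ScalaTest 2.x Timeouts; the body here finishes well inside the limit, so the interruptor never fires):

    import scala.concurrent.duration._
    import scala.language.postfixOps

    import org.scalatest.concurrent.Interruptor
    import org.scalatest.concurrent.Timeouts._

    val interruptor: Interruptor = new Interruptor {
      override def apply(thread: Thread): Unit = {
        // Release any external resource the blocked thread may be waiting on, then interrupt it.
        thread.interrupt()
      }
    }

    failAfter(2 seconds) {
      Thread.sleep(100)  // stand-in for the real work
    } (interruptor)
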
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
index 15add84878ecf..34fedead44db3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
@@ -30,6 +30,12 @@ class AnalysisException protected[sql] (
val startPosition: Option[Int] = None)
extends Exception with Serializable {
+  def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = {
+ val newException = new AnalysisException(message, line, startPosition)
+ newException.setStackTrace(getStackTrace)
+ newException
+ }
+
override def getMessage: String = {
val lineAnnotation = line.map(l => s" line $l").getOrElse("")
val positionAnnotation = startPosition.map(p => s" pos $p").getOrElse("")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
index 366be00473d1c..3823584287741 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
@@ -26,7 +26,7 @@ import scala.util.parsing.input.CharArrayReader.EofCh
import org.apache.spark.sql.catalyst.plans.logical._
private[sql] object KeywordNormalizer {
- def apply(str: String) = str.toLowerCase()
+ def apply(str: String): String = str.toLowerCase()
}
private[sql] abstract class AbstractSparkSQLParser
@@ -42,7 +42,7 @@ private[sql] abstract class AbstractSparkSQLParser
}
protected case class Keyword(str: String) {
- def normalize = KeywordNormalizer(str)
+ def normalize: String = KeywordNormalizer(str)
def parser: Parser[String] = normalize
}
@@ -81,7 +81,7 @@ private[sql] abstract class AbstractSparkSQLParser
class SqlLexical extends StdLexical {
case class FloatLit(chars: String) extends Token {
- override def toString = chars
+ override def toString: String = chars
}
/* This is a work around to support the lazy setting */
@@ -120,7 +120,7 @@ class SqlLexical extends StdLexical {
| failure("illegal character")
)
- override def identChar = letter | elem('_')
+ override def identChar: Parser[Elem] = letter | elem('_')
override def whitespace: Parser[Any] =
( whitespaceChar
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 92d3db077c5e1..44eceb0b372e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -64,9 +64,7 @@ class Analyzer(catalog: Catalog,
UnresolvedHavingClauseAttributes ::
TrimGroupingAliases ::
typeCoercionRules ++
- extendedResolutionRules : _*),
- Batch("Remove SubQueries", fixedPoint,
- EliminateSubQueries)
+ extendedResolutionRules : _*)
)
/**
@@ -170,12 +168,12 @@ class Analyzer(catalog: Catalog,
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
object ResolveRelations extends Rule[LogicalPlan] {
- def getTable(u: UnresolvedRelation) = {
+ def getTable(u: UnresolvedRelation): LogicalPlan = {
try {
catalog.lookupRelation(u.tableIdentifier, u.alias)
} catch {
case _: NoSuchTableException =>
- u.failAnalysis(s"no such table ${u.tableIdentifier}")
+ u.failAnalysis(s"no such table ${u.tableName}")
}
}
@@ -275,7 +273,8 @@ class Analyzer(catalog: Catalog,
q.asInstanceOf[GroupingAnalytics].gid
case u @ UnresolvedAttribute(name) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
- val result = q.resolveChildren(name, resolver).getOrElse(u)
+ val result =
+ withPosition(u) { q.resolveChildren(name, resolver).getOrElse(u) }
logDebug(s"Resolving $u to $result")
result
case UnresolvedGetField(child, fieldName) if child.resolved =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index 9e6e2912e0622..5eb7dff0cede8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -86,12 +86,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
tables += ((getDbTableName(tableIdent), plan))
}
- override def unregisterTable(tableIdentifier: Seq[String]) = {
+ override def unregisterTable(tableIdentifier: Seq[String]): Unit = {
val tableIdent = processTableIdentifier(tableIdentifier)
tables -= getDbTableName(tableIdent)
}
- override def unregisterAllTables() = {
+ override def unregisterAllTables(): Unit = {
tables.clear()
}
@@ -147,8 +147,8 @@ trait OverrideCatalog extends Catalog {
}
abstract override def lookupRelation(
- tableIdentifier: Seq[String],
- alias: Option[String] = None): LogicalPlan = {
+ tableIdentifier: Seq[String],
+ alias: Option[String] = None): LogicalPlan = {
val tableIdent = processTableIdentifier(tableIdentifier)
val overriddenTable = overrides.get(getDBTable(tableIdent))
val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r))
@@ -205,15 +205,15 @@ trait OverrideCatalog extends Catalog {
*/
object EmptyCatalog extends Catalog {
- val caseSensitive: Boolean = true
+ override val caseSensitive: Boolean = true
- def tableExists(tableIdentifier: Seq[String]): Boolean = {
+ override def tableExists(tableIdentifier: Seq[String]): Boolean = {
throw new UnsupportedOperationException
}
- def lookupRelation(
- tableIdentifier: Seq[String],
- alias: Option[String] = None) = {
+ override def lookupRelation(
+ tableIdentifier: Seq[String],
+ alias: Option[String] = None): LogicalPlan = {
throw new UnsupportedOperationException
}
@@ -221,11 +221,11 @@ object EmptyCatalog extends Catalog {
throw new UnsupportedOperationException
}
- def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = {
+ override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = {
throw new UnsupportedOperationException
}
- def unregisterTable(tableIdentifier: Seq[String]): Unit = {
+ override def unregisterTable(tableIdentifier: Seq[String]): Unit = {
throw new UnsupportedOperationException
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 425e1e41cbf21..40472a1cbb3b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -33,7 +33,7 @@ class CheckAnalysis {
*/
val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil
- def failAnalysis(msg: String) = {
+ def failAnalysis(msg: String): Nothing = {
throw new AnalysisException(msg)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 9f334f6d42ad1..c43ea55899695 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -35,7 +35,7 @@ trait OverrideFunctionRegistry extends FunctionRegistry {
val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive)
- def registerFunction(name: String, builder: FunctionBuilder) = {
+ override def registerFunction(name: String, builder: FunctionBuilder): Unit = {
functionBuilders.put(name, builder)
}
@@ -47,7 +47,7 @@ trait OverrideFunctionRegistry extends FunctionRegistry {
class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistry {
val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive)
- def registerFunction(name: String, builder: FunctionBuilder) = {
+ override def registerFunction(name: String, builder: FunctionBuilder): Unit = {
functionBuilders.put(name, builder)
}
@@ -61,13 +61,15 @@ class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistr
* functions are already filled in and the analyser needs only to resolve attribute references.
*/
object EmptyFunctionRegistry extends FunctionRegistry {
- def registerFunction(name: String, builder: FunctionBuilder) = ???
+ override def registerFunction(name: String, builder: FunctionBuilder): Unit = {
+ throw new UnsupportedOperationException
+ }
- def lookupFunction(name: String, children: Seq[Expression]): Expression = {
+ override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
throw new UnsupportedOperationException
}
- def caseSensitive: Boolean = ???
+ override def caseSensitive: Boolean = throw new UnsupportedOperationException
}
/**
@@ -76,7 +78,7 @@ object EmptyFunctionRegistry extends FunctionRegistry {
* TODO move this into util folder?
*/
object StringKeyHashMap {
- def apply[T](caseSensitive: Boolean) = caseSensitive match {
+ def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match {
case false => new StringKeyHashMap[T](_.toLowerCase)
case true => new StringKeyHashMap[T](identity)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
index e95f19e69ed43..c61c395cb4bb1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
@@ -38,8 +38,16 @@ package object analysis {
implicit class AnalysisErrorAt(t: TreeNode[_]) {
/** Fails the analysis at the point where a specific tree node was parsed. */
- def failAnalysis(msg: String) = {
+ def failAnalysis(msg: String): Nothing = {
throw new AnalysisException(msg, t.origin.line, t.origin.startPosition)
}
}
+
+ /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */
+  def withPosition[A](t: TreeNode[_])(f: => A): A = {
+ try f catch {
+ case a: AnalysisException =>
+ throw a.withPosition(t.origin.line, t.origin.startPosition)
+ }
+ }
}
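
The helper above is a small "decorate and rethrow" combinator. A self-contained sketch of the same shape, using a hypothetical PositionedError in place of AnalysisException (whose constructor is package-private):

    // Hypothetical stand-in for AnalysisException: an error that can carry a source position.
    case class PositionedError(msg: String, line: Option[Int] = None, pos: Option[Int] = None)
      extends Exception(msg) {
      def withPosition(line: Option[Int], pos: Option[Int]): PositionedError = {
        val relocated = PositionedError(msg, line, pos)
        relocated.setStackTrace(getStackTrace)  // keep the original stack trace, as the patch does
        relocated
      }
    }

    // Same shape as withPosition above: run f, and re-throw any failure with the position attached.
    def withPosition[A](line: Option[Int], pos: Option[Int])(f: => A): A =
      try f catch {
        case e: PositionedError => throw e.withPosition(line, pos)
      }
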
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index a7cd4124e56f3..300e9ba187bc5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.types.DataType
/**
* Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully
@@ -36,7 +37,12 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str
case class UnresolvedRelation(
tableIdentifier: Seq[String],
alias: Option[String] = None) extends LeafNode {
- override def output = Nil
+
+ /** Returns a `.` separated name for this relation. */
+ def tableName: String = tableIdentifier.mkString(".")
+
+ override def output: Seq[Attribute] = Nil
+
override lazy val resolved = false
}
@@ -44,16 +50,16 @@ case class UnresolvedRelation(
* Holds the name of an attribute that has yet to be resolved.
*/
case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] {
- override def exprId = throw new UnresolvedException(this, "exprId")
- override def dataType = throw new UnresolvedException(this, "dataType")
- override def nullable = throw new UnresolvedException(this, "nullable")
- override def qualifiers = throw new UnresolvedException(this, "qualifiers")
+ override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
override lazy val resolved = false
- override def newInstance() = this
- override def withNullability(newNullability: Boolean) = this
- override def withQualifiers(newQualifiers: Seq[String]) = this
- override def withName(newName: String) = UnresolvedAttribute(name)
+ override def newInstance(): UnresolvedAttribute = this
+ override def withNullability(newNullability: Boolean): UnresolvedAttribute = this
+ override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this
+ override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute(name)
// Unresolved attributes are transient at compile time and don't get evaluated during execution.
override def eval(input: Row = null): EvaluatedType =
@@ -63,16 +69,16 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
}
case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression {
- override def dataType = throw new UnresolvedException(this, "dataType")
- override def foldable = throw new UnresolvedException(this, "foldable")
- override def nullable = throw new UnresolvedException(this, "nullable")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
// Unresolved functions are transient at compile time and don't get evaluated during execution.
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
- override def toString = s"'$name(${children.mkString(",")})"
+ override def toString: String = s"'$name(${children.mkString(",")})"
}
/**
@@ -82,17 +88,17 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E
trait Star extends Attribute with trees.LeafNode[Expression] {
self: Product =>
- override def name = throw new UnresolvedException(this, "name")
- override def exprId = throw new UnresolvedException(this, "exprId")
- override def dataType = throw new UnresolvedException(this, "dataType")
- override def nullable = throw new UnresolvedException(this, "nullable")
- override def qualifiers = throw new UnresolvedException(this, "qualifiers")
+ override def name: String = throw new UnresolvedException(this, "name")
+ override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
override lazy val resolved = false
- override def newInstance() = this
- override def withNullability(newNullability: Boolean) = this
- override def withQualifiers(newQualifiers: Seq[String]) = this
- override def withName(newName: String) = this
+ override def newInstance(): Star = this
+ override def withNullability(newNullability: Boolean): Star = this
+ override def withQualifiers(newQualifiers: Seq[String]): Star = this
+ override def withName(newName: String): Star = this
// Star gets expanded at runtime so we never evaluate a Star.
override def eval(input: Row = null): EvaluatedType =
@@ -125,7 +131,7 @@ case class UnresolvedStar(table: Option[String]) extends Star {
}
}
- override def toString = table.map(_ + ".").getOrElse("") + "*"
+ override def toString: String = table.map(_ + ".").getOrElse("") + "*"
}
/**
@@ -140,25 +146,25 @@ case class UnresolvedStar(table: Option[String]) extends Star {
case class MultiAlias(child: Expression, names: Seq[String])
extends Attribute with trees.UnaryNode[Expression] {
- override def name = throw new UnresolvedException(this, "name")
+ override def name: String = throw new UnresolvedException(this, "name")
- override def exprId = throw new UnresolvedException(this, "exprId")
+ override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
- override def dataType = throw new UnresolvedException(this, "dataType")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
- override def nullable = throw new UnresolvedException(this, "nullable")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
- override def qualifiers = throw new UnresolvedException(this, "qualifiers")
+ override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
override lazy val resolved = false
- override def newInstance() = this
+ override def newInstance(): MultiAlias = this
- override def withNullability(newNullability: Boolean) = this
+ override def withNullability(newNullability: Boolean): MultiAlias = this
- override def withQualifiers(newQualifiers: Seq[String]) = this
+ override def withQualifiers(newQualifiers: Seq[String]): MultiAlias = this
- override def withName(newName: String) = this
+ override def withName(newName: String): MultiAlias = this
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
@@ -175,17 +181,17 @@ case class MultiAlias(child: Expression, names: Seq[String])
*/
case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions
- override def toString = expressions.mkString("ResolvedStar(", ", ", ")")
+ override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
}
case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
- override def dataType = throw new UnresolvedException(this, "dataType")
- override def foldable = throw new UnresolvedException(this, "foldable")
- override def nullable = throw new UnresolvedException(this, "nullable")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
- override def toString = s"$child.$fieldName"
+ override def toString: String = s"$child.$fieldName"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 51a09ac0e1249..145f062dd6817 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -61,60 +61,60 @@ package object dsl {
trait ImplicitOperators {
def expr: Expression
- def unary_- = UnaryMinus(expr)
- def unary_! = Not(expr)
- def unary_~ = BitwiseNot(expr)
-
- def + (other: Expression) = Add(expr, other)
- def - (other: Expression) = Subtract(expr, other)
- def * (other: Expression) = Multiply(expr, other)
- def / (other: Expression) = Divide(expr, other)
- def % (other: Expression) = Remainder(expr, other)
- def & (other: Expression) = BitwiseAnd(expr, other)
- def | (other: Expression) = BitwiseOr(expr, other)
- def ^ (other: Expression) = BitwiseXor(expr, other)
-
- def && (other: Expression) = And(expr, other)
- def || (other: Expression) = Or(expr, other)
-
- def < (other: Expression) = LessThan(expr, other)
- def <= (other: Expression) = LessThanOrEqual(expr, other)
- def > (other: Expression) = GreaterThan(expr, other)
- def >= (other: Expression) = GreaterThanOrEqual(expr, other)
- def === (other: Expression) = EqualTo(expr, other)
- def <=> (other: Expression) = EqualNullSafe(expr, other)
- def !== (other: Expression) = Not(EqualTo(expr, other))
-
- def in(list: Expression*) = In(expr, list)
-
- def like(other: Expression) = Like(expr, other)
- def rlike(other: Expression) = RLike(expr, other)
- def contains(other: Expression) = Contains(expr, other)
- def startsWith(other: Expression) = StartsWith(expr, other)
- def endsWith(other: Expression) = EndsWith(expr, other)
- def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)) =
+    def unary_- : Expression = UnaryMinus(expr)
+ def unary_! : Predicate = Not(expr)
+ def unary_~ : Expression = BitwiseNot(expr)
+
+ def + (other: Expression): Expression = Add(expr, other)
+ def - (other: Expression): Expression = Subtract(expr, other)
+ def * (other: Expression): Expression = Multiply(expr, other)
+ def / (other: Expression): Expression = Divide(expr, other)
+ def % (other: Expression): Expression = Remainder(expr, other)
+ def & (other: Expression): Expression = BitwiseAnd(expr, other)
+ def | (other: Expression): Expression = BitwiseOr(expr, other)
+ def ^ (other: Expression): Expression = BitwiseXor(expr, other)
+
+ def && (other: Expression): Predicate = And(expr, other)
+ def || (other: Expression): Predicate = Or(expr, other)
+
+ def < (other: Expression): Predicate = LessThan(expr, other)
+ def <= (other: Expression): Predicate = LessThanOrEqual(expr, other)
+ def > (other: Expression): Predicate = GreaterThan(expr, other)
+ def >= (other: Expression): Predicate = GreaterThanOrEqual(expr, other)
+ def === (other: Expression): Predicate = EqualTo(expr, other)
+ def <=> (other: Expression): Predicate = EqualNullSafe(expr, other)
+ def !== (other: Expression): Predicate = Not(EqualTo(expr, other))
+
+ def in(list: Expression*): Expression = In(expr, list)
+
+ def like(other: Expression): Expression = Like(expr, other)
+ def rlike(other: Expression): Expression = RLike(expr, other)
+ def contains(other: Expression): Expression = Contains(expr, other)
+ def startsWith(other: Expression): Expression = StartsWith(expr, other)
+ def endsWith(other: Expression): Expression = EndsWith(expr, other)
+ def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)): Expression =
Substring(expr, pos, len)
- def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)) =
+ def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)): Expression =
Substring(expr, pos, len)
- def isNull = IsNull(expr)
- def isNotNull = IsNotNull(expr)
+ def isNull: Predicate = IsNull(expr)
+ def isNotNull: Predicate = IsNotNull(expr)
- def getItem(ordinal: Expression) = GetItem(expr, ordinal)
- def getField(fieldName: String) = UnresolvedGetField(expr, fieldName)
+ def getItem(ordinal: Expression): Expression = GetItem(expr, ordinal)
+ def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, fieldName)
- def cast(to: DataType) = Cast(expr, to)
+ def cast(to: DataType): Expression = Cast(expr, to)
- def asc = SortOrder(expr, Ascending)
- def desc = SortOrder(expr, Descending)
+ def asc: SortOrder = SortOrder(expr, Ascending)
+ def desc: SortOrder = SortOrder(expr, Descending)
- def as(alias: String) = Alias(expr, alias)()
- def as(alias: Symbol) = Alias(expr, alias.name)()
+ def as(alias: String): NamedExpression = Alias(expr, alias)()
+ def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)()
}
trait ExpressionConversions {
implicit class DslExpression(e: Expression) extends ImplicitOperators {
- def expr = e
+ def expr: Expression = e
}
implicit def booleanToLiteral(b: Boolean): Literal = Literal(b)
@@ -144,94 +144,100 @@ package object dsl {
}
}
- def sum(e: Expression) = Sum(e)
- def sumDistinct(e: Expression) = SumDistinct(e)
- def count(e: Expression) = Count(e)
- def countDistinct(e: Expression*) = CountDistinct(e)
- def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd)
- def avg(e: Expression) = Average(e)
- def first(e: Expression) = First(e)
- def last(e: Expression) = Last(e)
- def min(e: Expression) = Min(e)
- def max(e: Expression) = Max(e)
- def upper(e: Expression) = Upper(e)
- def lower(e: Expression) = Lower(e)
- def sqrt(e: Expression) = Sqrt(e)
- def abs(e: Expression) = Abs(e)
-
- implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
+ def sum(e: Expression): Expression = Sum(e)
+ def sumDistinct(e: Expression): Expression = SumDistinct(e)
+ def count(e: Expression): Expression = Count(e)
+ def countDistinct(e: Expression*): Expression = CountDistinct(e)
+ def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression =
+ ApproxCountDistinct(e, rsd)
+ def avg(e: Expression): Expression = Average(e)
+ def first(e: Expression): Expression = First(e)
+ def last(e: Expression): Expression = Last(e)
+ def min(e: Expression): Expression = Min(e)
+ def max(e: Expression): Expression = Max(e)
+ def upper(e: Expression): Expression = Upper(e)
+ def lower(e: Expression): Expression = Lower(e)
+ def sqrt(e: Expression): Expression = Sqrt(e)
+ def abs(e: Expression): Expression = Abs(e)
+
+ implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
// TODO more implicit class for literal?
implicit class DslString(val s: String) extends ImplicitOperators {
override def expr: Expression = Literal(s)
- def attr = analysis.UnresolvedAttribute(s)
+ def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s)
}
abstract class ImplicitAttribute extends ImplicitOperators {
def s: String
- def expr = attr
- def attr = analysis.UnresolvedAttribute(s)
+ def expr: UnresolvedAttribute = attr
+ def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s)
/** Creates a new AttributeReference of type boolean */
- def boolean = AttributeReference(s, BooleanType, nullable = true)()
+ def boolean: AttributeReference = AttributeReference(s, BooleanType, nullable = true)()
/** Creates a new AttributeReference of type byte */
- def byte = AttributeReference(s, ByteType, nullable = true)()
+ def byte: AttributeReference = AttributeReference(s, ByteType, nullable = true)()
/** Creates a new AttributeReference of type short */
- def short = AttributeReference(s, ShortType, nullable = true)()
+ def short: AttributeReference = AttributeReference(s, ShortType, nullable = true)()
/** Creates a new AttributeReference of type int */
- def int = AttributeReference(s, IntegerType, nullable = true)()
+ def int: AttributeReference = AttributeReference(s, IntegerType, nullable = true)()
/** Creates a new AttributeReference of type long */
- def long = AttributeReference(s, LongType, nullable = true)()
+ def long: AttributeReference = AttributeReference(s, LongType, nullable = true)()
/** Creates a new AttributeReference of type float */
- def float = AttributeReference(s, FloatType, nullable = true)()
+ def float: AttributeReference = AttributeReference(s, FloatType, nullable = true)()
/** Creates a new AttributeReference of type double */
- def double = AttributeReference(s, DoubleType, nullable = true)()
+ def double: AttributeReference = AttributeReference(s, DoubleType, nullable = true)()
/** Creates a new AttributeReference of type string */
- def string = AttributeReference(s, StringType, nullable = true)()
+ def string: AttributeReference = AttributeReference(s, StringType, nullable = true)()
/** Creates a new AttributeReference of type date */
- def date = AttributeReference(s, DateType, nullable = true)()
+ def date: AttributeReference = AttributeReference(s, DateType, nullable = true)()
/** Creates a new AttributeReference of type decimal */
- def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)()
+ def decimal: AttributeReference =
+ AttributeReference(s, DecimalType.Unlimited, nullable = true)()
/** Creates a new AttributeReference of type decimal */
- def decimal(precision: Int, scale: Int) =
+ def decimal(precision: Int, scale: Int): AttributeReference =
AttributeReference(s, DecimalType(precision, scale), nullable = true)()
/** Creates a new AttributeReference of type timestamp */
- def timestamp = AttributeReference(s, TimestampType, nullable = true)()
+ def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)()
/** Creates a new AttributeReference of type binary */
- def binary = AttributeReference(s, BinaryType, nullable = true)()
+ def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)()
/** Creates a new AttributeReference of type array */
- def array(dataType: DataType) = AttributeReference(s, ArrayType(dataType), nullable = true)()
+ def array(dataType: DataType): AttributeReference =
+ AttributeReference(s, ArrayType(dataType), nullable = true)()
/** Creates a new AttributeReference of type map */
def map(keyType: DataType, valueType: DataType): AttributeReference =
map(MapType(keyType, valueType))
- def map(mapType: MapType) = AttributeReference(s, mapType, nullable = true)()
+
+ def map(mapType: MapType): AttributeReference =
+ AttributeReference(s, mapType, nullable = true)()
/** Creates a new AttributeReference of type struct */
def struct(fields: StructField*): AttributeReference = struct(StructType(fields))
- def struct(structType: StructType) = AttributeReference(s, structType, nullable = true)()
+ def struct(structType: StructType): AttributeReference =
+ AttributeReference(s, structType, nullable = true)()
}
implicit class DslAttribute(a: AttributeReference) {
- def notNull = a.withNullability(false)
- def nullable = a.withNullability(true)
+ def notNull: AttributeReference = a.withNullability(false)
+ def nullable: AttributeReference = a.withNullability(true)
// Protobuf terminology
- def required = a.withNullability(false)
+ def required: AttributeReference = a.withNullability(false)
- def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable)
+ def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable)
}
}
@@ -241,23 +247,23 @@ package object dsl {
abstract class LogicalPlanFunctions {
def logicalPlan: LogicalPlan
- def select(exprs: NamedExpression*) = Project(exprs, logicalPlan)
+ def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan)
- def where(condition: Expression) = Filter(condition, logicalPlan)
+ def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)
- def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan)
+ def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan)
def join(
otherPlan: LogicalPlan,
joinType: JoinType = Inner,
- condition: Option[Expression] = None) =
+ condition: Option[Expression] = None): LogicalPlan =
Join(logicalPlan, otherPlan, joinType, condition)
- def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, true, logicalPlan)
+ def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan)
- def sortBy(sortExprs: SortOrder*) = Sort(sortExprs, false, logicalPlan)
+ def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan)
- def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*) = {
+ def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = {
val aliasedExprs = aggregateExprs.map {
case ne: NamedExpression => ne
case e => Alias(e, e.toString)()
@@ -265,41 +271,43 @@ package object dsl {
Aggregate(groupingExprs, aliasedExprs, logicalPlan)
}
- def subquery(alias: Symbol) = Subquery(alias.name, logicalPlan)
+ def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan)
- def unionAll(otherPlan: LogicalPlan) = Union(logicalPlan, otherPlan)
+ def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan)
- def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) =
+ def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan =
Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)
def sample(
fraction: Double,
withReplacement: Boolean = true,
- seed: Int = (math.random * 1000).toInt) =
+ seed: Int = (math.random * 1000).toInt): LogicalPlan =
Sample(fraction, withReplacement, seed, logicalPlan)
def generate(
generator: Generator,
join: Boolean = false,
outer: Boolean = false,
- alias: Option[String] = None) =
+ alias: Option[String] = None): LogicalPlan =
Generate(generator, join, outer, None, logicalPlan)
- def insertInto(tableName: String, overwrite: Boolean = false) =
+ def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(
analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite)
- def analyze = analysis.SimpleAnalyzer(logicalPlan)
+ def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer(logicalPlan))
}
object plans { // scalastyle:ignore
implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) extends LogicalPlanFunctions {
- def writeToFile(path: String) = WriteToFile(path, logicalPlan)
+ def writeToFile(path: String): LogicalPlan = WriteToFile(path, logicalPlan)
}
}
case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) {
- def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args)
+ def call(args: Expression*): ScalaUdf = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args)
+ }
}
// scalastyle:off
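As a side note on the return-type annotations added throughout this DSL, a small standalone sketch (the types below are illustrative stand-ins, not the Catalyst classes) shows what they guard against: without an ascription, Scala infers the most specific type, so a combinator like `select` would expose the concrete `Project` node rather than `LogicalPlan` in its public signature.

object DslSketch {
  sealed trait LogicalPlan
  case class Project(exprs: Seq[String], child: LogicalPlan) extends LogicalPlan
  case class Filter(condition: String, child: LogicalPlan) extends LogicalPlan
  case object OneRowRelation extends LogicalPlan

  implicit class DslPlan(val plan: LogicalPlan) {
    // Without the ascription Scala would infer Project here, leaking the concrete
    // node type into the public signature; the annotation keeps it at LogicalPlan.
    def select(exprs: String*): LogicalPlan = Project(exprs, plan)
    def where(condition: String): LogicalPlan = Filter(condition, plan)
  }

  def main(args: Array[String]): Unit = {
    val plan = OneRowRelation.select("a", "b").where("a > 1")
    println(plan) // Filter wrapping Project wrapping OneRowRelation
  }
}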
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 82e760b6c6916..96a11e352ec50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -23,7 +23,9 @@ package org.apache.spark.sql.catalyst.expressions
* of the name, or the expected nullability).
*/
object AttributeMap {
- def apply[A](kvs: Seq[(Attribute, A)]) = new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
+ def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
+ new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
+ }
}
class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
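A simplified analogue of what keying by `exprId` buys (the `ExprId`/`Attribute` types here are stand-ins, not the Catalyst ones): two attribute instances that differ only in name or nullability still resolve to the same entry.

object AttributeMapSketch {
  final case class ExprId(id: Long)
  final case class Attribute(name: String, exprId: ExprId, nullable: Boolean)

  // Same shape as the patched apply: key by exprId, keep the (attribute, value) pair.
  final class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) {
    def get(a: Attribute): Option[A] = baseMap.get(a.exprId).map(_._2)
  }
  object AttributeMap {
    def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] =
      new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
  }

  def main(args: Array[String]): Unit = {
    val a  = Attribute("a", ExprId(1), nullable = true)
    val a2 = Attribute("A", ExprId(1), nullable = false) // same id, different name/nullability
    val m  = AttributeMap(Seq(a -> "value"))
    println(m.get(a2)) // Some(value): lookup goes through the exprId, not the name
  }
}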
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index adaeab0b5c027..f9ae85a5cfc1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -19,27 +19,27 @@ package org.apache.spark.sql.catalyst.expressions
protected class AttributeEquals(val a: Attribute) {
- override def hashCode() = a match {
+ override def hashCode(): Int = a match {
case ar: AttributeReference => ar.exprId.hashCode()
case a => a.hashCode()
}
- override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match {
+ override def equals(other: Any): Boolean = (a, other.asInstanceOf[AttributeEquals].a) match {
case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId
case (a1, a2) => a1 == a2
}
}
object AttributeSet {
- def apply(a: Attribute) =
- new AttributeSet(Set(new AttributeEquals(a)))
+ def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a)))
/** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
- def apply(baseSet: Seq[Expression]) =
+ def apply(baseSet: Seq[Expression]): AttributeSet = {
new AttributeSet(
baseSet
.flatMap(_.references)
.map(new AttributeEquals(_)).toSet)
+ }
}
/**
@@ -57,7 +57,7 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
extends Traversable[Attribute] with Serializable {
/** Returns true if the members of this AttributeSet and other are the same. */
- override def equals(other: Any) = other match {
+ override def equals(other: Any): Boolean = other match {
case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains)
case _ => false
}
@@ -81,32 +81,34 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
* Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in
* `other`.
*/
- def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet)
+ def subsetOf(other: AttributeSet): Boolean = baseSet.subsetOf(other.baseSet)
/**
* Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found
* in `other`.
*/
- def --(other: Traversable[NamedExpression]) =
+ def --(other: Traversable[NamedExpression]): AttributeSet =
new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute)))
/**
* Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found
* in `other`.
*/
- def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet)
+ def ++(other: AttributeSet): AttributeSet = new AttributeSet(baseSet ++ other.baseSet)
/**
* Returns a new [[AttributeSet]] that contains only the [[Attribute Attributes]] where `f` evaluates to
* true.
*/
- override def filter(f: Attribute => Boolean) = new AttributeSet(baseSet.filter(ae => f(ae.a)))
+ override def filter(f: Attribute => Boolean): AttributeSet =
+ new AttributeSet(baseSet.filter(ae => f(ae.a)))
/**
* Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in
* `this` and `other`.
*/
- def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet))
+ def intersect(other: AttributeSet): AttributeSet =
+ new AttributeSet(baseSet.intersect(other.baseSet))
override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f)
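The `AttributeEquals` wrapper above is a reusable pattern: to get set semantics keyed on a single field, wrap elements in a type whose `equals`/`hashCode` delegate to that field. A standalone sketch (types are illustrative stand-ins):

object AttributeSetSketch {
  final case class ExprId(id: Long)
  final case class Attribute(name: String, exprId: ExprId)

  // Equality and hashing delegate to exprId, mirroring AttributeEquals.
  final class ByExprId(val a: Attribute) {
    override def hashCode(): Int = a.exprId.hashCode()
    override def equals(other: Any): Boolean = other match {
      case that: ByExprId => a.exprId == that.a.exprId
      case _ => false
    }
  }

  def main(args: Array[String]): Unit = {
    val x  = Attribute("x", ExprId(7))
    val x2 = Attribute("x_renamed", ExprId(7))
    val set = Set(new ByExprId(x))
    println(set.contains(new ByExprId(x2))) // true: same exprId, so same member
  }
}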
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 76a9f08dea85f..2225621dbaabd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -32,7 +32,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
type EvaluatedType = Any
- override def toString = s"input[$ordinal]"
+ override def toString: String = s"input[$ordinal]"
override def eval(input: Row): Any = input(ordinal)
}
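A minimal sketch of what binding an attribute to an ordinal means (the Row below is a plain IndexedSeq stand-in, not Catalyst's Row): once bound, evaluation is just a positional lookup.

object BoundReferenceSketch {
  type Row = IndexedSeq[Any] // stand-in for Catalyst's Row

  // Mirrors BoundReference: evaluation ignores names and reads position `ordinal`.
  final case class BoundRef(ordinal: Int) {
    def eval(input: Row): Any = input(ordinal)
    override def toString: String = s"input[$ordinal]"
  }

  def main(args: Array[String]): Unit = {
    val row: Row = Vector("alice", 42, true)
    val age = BoundRef(1)
    println(s"$age = ${age.eval(row)}") // input[1] = 42
  }
}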
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index b1bc858478ee1..31f1a5fdc7e53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -29,9 +29,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
override lazy val resolved = childrenResolved && resolve(child.dataType, dataType)
- override def foldable = child.foldable
+ override def foldable: Boolean = child.foldable
- override def nullable = forceNullable(child.dataType, dataType) || child.nullable
+ override def nullable: Boolean = forceNullable(child.dataType, dataType) || child.nullable
private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match {
case (StringType, _: NumericType) => true
@@ -103,7 +103,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
}
}
- override def toString = s"CAST($child, $dataType)"
+ override def toString: String = s"CAST($child, $dataType)"
type EvaluatedType = Any
@@ -394,10 +394,17 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
val casts = from.fields.zip(to.fields).map {
case (fromField, toField) => cast(fromField.dataType, toField.dataType)
}
- // TODO: This is very slow!
- buildCast[Row](_, row => Row(row.toSeq.zip(casts).map {
- case (v, cast) => if (v == null) null else cast(v)
- }: _*))
+ // TODO: Could be faster?
+ val newRow = new GenericMutableRow(from.fields.size)
+ buildCast[Row](_, row => {
+ var i = 0
+ while (i < row.length) {
+ val v = row(i)
+ newRow.update(i, if (v == null) null else casts(i)(v))
+ i += 1
+ }
+ newRow.copy()
+ })
}
private[this] def cast(from: DataType, to: DataType): Any => Any = to match {
@@ -430,14 +437,14 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
object Cast {
// `SimpleDateFormat` is not thread-safe.
private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] {
- override def initialValue() = {
+ override def initialValue(): SimpleDateFormat = {
new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
}
}
// `SimpleDateFormat` is not thread-safe.
private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] {
- override def initialValue() = {
+ override def initialValue(): SimpleDateFormat = {
new SimpleDateFormat("yyyy-MM-dd")
}
}
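The struct-cast rewrite in this file follows a common hot-path pattern: fill one reusable mutable buffer inside the closure and hand callers a copy, instead of building a fresh Seq and spreading it through varargs on every row. A standalone sketch of the same pattern, with plain arrays standing in for GenericMutableRow/Row:

object ReusedBufferSketch {
  // castRow returns a converter that applies casts(i) to field i of each row.
  def castRow(casts: Array[Any => Any]): Array[Any] => Array[Any] = {
    val buffer = new Array[Any](casts.length) // allocated once, reused per call
    (row: Array[Any]) => {
      var i = 0
      while (i < row.length) {
        val v = row(i)
        buffer(i) = if (v == null) null else casts(i)(v)
        i += 1
      }
      buffer.clone() // copy so callers never observe later mutation of the buffer
    }
  }

  def main(args: Array[String]): Unit = {
    val toStrings: Array[Any => Any] = Array((v: Any) => v.toString, (v: Any) => v.toString)
    val cast = castRow(toStrings)
    println(cast(Array[Any](1, null)).toSeq) // 1 becomes "1", the null is preserved
  }
}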
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 6ad39b8372cfb..4e3bbc06a5b4c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -65,7 +65,7 @@ abstract class Expression extends TreeNode[Expression] {
* Returns true if all the children of this expression have been resolved to a specific schema
* and false if any still contains any unresolved placeholders.
*/
- def childrenResolved = !children.exists(!_.resolved)
+ def childrenResolved: Boolean = !children.exists(!_.resolved)
/**
* Returns a string representation of this expression that does not have developer centric
@@ -84,9 +84,9 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
def symbol: String
- override def foldable = left.foldable && right.foldable
+ override def foldable: Boolean = left.foldable && right.foldable
- override def toString = s"($left $symbol $right)"
+ override def toString: String = s"($left $symbol $right)"
}
abstract class LeafExpression extends Expression with trees.LeafNode[Expression] {
@@ -104,8 +104,8 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
case class GroupExpression(children: Seq[Expression]) extends Expression {
self: Product =>
type EvaluatedType = Seq[Any]
- override def eval(input: Row): EvaluatedType = ???
- override def nullable = false
- override def foldable = false
- override def dataType = ???
+ override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException
+ override def nullable: Boolean = false
+ override def foldable: Boolean = false
+ override def dataType: DataType = throw new UnsupportedOperationException
}
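For reference, `childrenResolved` is simply a forall over the children; a tiny standalone check (the Expr type is a stand-in, not Catalyst's Expression):

object ChildrenResolvedSketch {
  final case class Expr(resolved: Boolean, children: Seq[Expr] = Nil) {
    // Same logic as Expression.childrenResolved: no child may be unresolved.
    def childrenResolved: Boolean = !children.exists(!_.resolved)
  }

  def main(args: Array[String]): Unit = {
    println(Expr(resolved = false, Seq(Expr(true), Expr(true))).childrenResolved)  // true
    println(Expr(resolved = false, Seq(Expr(true), Expr(false))).childrenResolved) // false
  }
}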
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index db5d897ee569f..c2866cd955409 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -40,7 +40,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
new GenericRow(outputArray)
}
- override def toString = s"Row => [${exprArray.mkString(",")}]"
+ override def toString: String = s"Row => [${exprArray.mkString(",")}]"
}
/**
@@ -107,12 +107,12 @@ class JoinedRow extends Row {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length = row1.length + row2.length
+ override def length: Int = row1.length + row2.length
- override def apply(i: Int) =
+ override def apply(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)
- override def isNullAt(i: Int) =
+ override def isNullAt(i: Int): Boolean =
if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
override def getInt(i: Int): Int =
@@ -142,7 +142,7 @@ class JoinedRow extends Row {
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- override def copy() = {
+ override def copy(): Row = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
@@ -153,7 +153,7 @@ class JoinedRow extends Row {
new GenericRow(copiedValues)
}
- override def toString() = {
+ override def toString: String = {
// Make sure toString never throws NullPointerException.
if ((row1 eq null) && (row2 eq null)) {
"[ empty row ]"
@@ -207,12 +207,12 @@ class JoinedRow2 extends Row {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length = row1.length + row2.length
+ override def length: Int = row1.length + row2.length
- override def apply(i: Int) =
+ override def apply(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)
- override def isNullAt(i: Int) =
+ override def isNullAt(i: Int): Boolean =
if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
override def getInt(i: Int): Int =
@@ -242,7 +242,7 @@ class JoinedRow2 extends Row {
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- override def copy() = {
+ override def copy(): Row = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
@@ -253,7 +253,7 @@ class JoinedRow2 extends Row {
new GenericRow(copiedValues)
}
- override def toString() = {
+ override def toString: String = {
// Make sure toString never throws NullPointerException.
if ((row1 eq null) && (row2 eq null)) {
"[ empty row ]"
@@ -301,12 +301,12 @@ class JoinedRow3 extends Row {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length = row1.length + row2.length
+ override def length: Int = row1.length + row2.length
- override def apply(i: Int) =
+ override def apply(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)
- override def isNullAt(i: Int) =
+ override def isNullAt(i: Int): Boolean =
if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
override def getInt(i: Int): Int =
@@ -336,7 +336,7 @@ class JoinedRow3 extends Row {
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- override def copy() = {
+ override def copy(): Row = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
@@ -347,7 +347,7 @@ class JoinedRow3 extends Row {
new GenericRow(copiedValues)
}
- override def toString() = {
+ override def toString: String = {
// Make sure toString never throws NullPointerException.
if ((row1 eq null) && (row2 eq null)) {
"[ empty row ]"
@@ -395,12 +395,12 @@ class JoinedRow4 extends Row {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length = row1.length + row2.length
+ override def length: Int = row1.length + row2.length
- override def apply(i: Int) =
+ override def apply(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)
- override def isNullAt(i: Int) =
+ override def isNullAt(i: Int): Boolean =
if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
override def getInt(i: Int): Int =
@@ -430,7 +430,7 @@ class JoinedRow4 extends Row {
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- override def copy() = {
+ override def copy(): Row = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
@@ -441,7 +441,7 @@ class JoinedRow4 extends Row {
new GenericRow(copiedValues)
}
- override def toString() = {
+ override def toString: String = {
// Make sure toString never throws NullPointerException.
if ((row1 eq null) && (row2 eq null)) {
"[ empty row ]"
@@ -489,12 +489,12 @@ class JoinedRow5 extends Row {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length = row1.length + row2.length
+ override def length: Int = row1.length + row2.length
- override def apply(i: Int) =
+ override def apply(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)
- override def isNullAt(i: Int) =
+ override def isNullAt(i: Int): Boolean =
if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
override def getInt(i: Int): Int =
@@ -524,7 +524,7 @@ class JoinedRow5 extends Row {
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- override def copy() = {
+ override def copy(): Row = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
@@ -535,7 +535,7 @@ class JoinedRow5 extends Row {
new GenericRow(copiedValues)
}
- override def toString() = {
+ override def toString: String = {
// Make sure toString never throws NullPointerException.
if ((row1 eq null) && (row2 eq null)) {
"[ empty row ]"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
index b2c6d3029031d..f5fea3f015dc4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
@@ -18,16 +18,19 @@
package org.apache.spark.sql.catalyst.expressions
import java.util.Random
-import org.apache.spark.sql.types.DoubleType
+
+import org.apache.spark.sql.types.{DataType, DoubleType}
case object Rand extends LeafExpression {
- override def dataType = DoubleType
- override def nullable = false
+ override def dataType: DataType = DoubleType
+ override def nullable: Boolean = false
private[this] lazy val rand = new Random
- override def eval(input: Row = null) = rand.nextDouble().asInstanceOf[EvaluatedType]
+ override def eval(input: Row = null): EvaluatedType = {
+ rand.nextDouble().asInstanceOf[EvaluatedType]
+ }
- override def toString = "RAND()"
+ override def toString: String = "RAND()"
}
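Rand is the simplest nondeterministic leaf: it ignores the input row and draws from a lazily created, per-instance Random. A standalone sketch of the same shape (not the Catalyst class):

object RandSketch {
  import java.util.Random

  final class RandExpr {
    // Lazy so the expression is cheap until first evaluated; per-instance so each
    // expression has its own stream of values.
    private[this] lazy val rand = new Random
    def eval(input: Any = null): Double = rand.nextDouble()
    override def toString: String = "RAND()"
  }

  def main(args: Array[String]): Unit = {
    val r = new RandExpr
    println(s"$r = ${r.eval()}") // e.g. RAND() = 0.7318...
  }
}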
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 8a36c6810790d..389dc4f745723 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -29,9 +29,9 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
type EvaluatedType = Any
- def nullable = true
+ override def nullable: Boolean = true
- override def toString = s"scalaUDF(${children.mkString(",")})"
+ override def toString: String = s"scalaUDF(${children.mkString(",")})"
// scalastyle:off
@@ -39,363 +39,669 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
(1 to 22).map { x =>
val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _)
- val evals = (0 to x - 1).map(x => s" ScalaReflection.convertToScala(children($x).eval(input), children($x).dataType)").reduce(_ + ",\n " + _)
-
- s"""
- case $x =>
- function.asInstanceOf[($anys) => Any](
- $evals)
- """
+ val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _)
+ val evals = (0 to x - 1).map(x => s"ScalaReflection.convertToScala(child$x.eval(input), child$x.dataType)").reduce(_ + ",\n " + _)
+
+ s""" case $x =>
+ val func = function.asInstanceOf[($anys) => Any]
+ $childs
+ (input: Row) => {
+ func(
+ $evals)
+ }
+ """
}.foreach(println)
*/
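Every generated case below follows one shape: cast `function` to the matching FunctionN once, pull each child out of `children` once, and return a `Row => Any` closure, so the per-row cost is limited to evaluating the children and converting their results. A standalone sketch of that shape for arity 2 (names are illustrative; `convert` stands in for ScalaReflection.convertToScala):

object UdfClosureSketch {
  type Row = IndexedSeq[Any]
  final case class Child(ordinal: Int) {
    def eval(input: Row): Any = input(ordinal)
  }
  // Identity here; in Catalyst this would be ScalaReflection.convertToScala(value, dataType).
  def convert(value: Any): Any = value

  def bind(function: AnyRef, children: Seq[Child]): Row => Any = children.size match {
    case 2 =>
      // Done once, when the closure is built, not on every row.
      val func = function.asInstanceOf[(Any, Any) => Any]
      val child0 = children(0)
      val child1 = children(1)
      (input: Row) => func(convert(child0.eval(input)), convert(child1.eval(input)))
    case n =>
      sys.error(s"sketch only covers arity 2, got $n")
  }

  def main(args: Array[String]): Unit = {
    val concat: (Any, Any) => Any = (a, b) => s"$a-$b"
    val f = bind(concat, Seq(Child(0), Child(1)))
    println(f(Vector("x", 7))) // x-7
  }
}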
-
- override def eval(input: Row): Any = {
- val result = children.size match {
- case 0 => function.asInstanceOf[() => Any]()
- case 1 =>
- function.asInstanceOf[(Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType))
-
-
- case 2 =>
- function.asInstanceOf[(Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType))
-
-
- case 3 =>
- function.asInstanceOf[(Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType))
-
-
- case 4 =>
- function.asInstanceOf[(Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType))
-
-
- case 5 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType))
-
-
- case 6 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType))
-
-
- case 7 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType))
-
-
- case 8 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
- ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType))
-
-
- case 9 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
- ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
- ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType))
-
-
- case 10 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
- ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
- ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
- ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType))
-
-
- case 11 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
- ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
- ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
- ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
- ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType))
-
-
- case 12 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
- ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
- ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
- ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
- ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
- ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType))
-
-
- case 13 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
- ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
- ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
- ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
- ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
- ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
- ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType))
-
-
- case 14 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
- ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
- ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
- ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
- ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
- ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
- ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
- ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType))
-
-
- case 15 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
- ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
- ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
- ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
- ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
- ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
- ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
- ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
- ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType))
-
-
- case 16 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
- ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
- ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
- ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
- ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
- ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
- ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
- ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
- ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
- ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType))
-
-
- case 17 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
- ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
- ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
- ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
- ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
- ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
- ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
- ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
- ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
- ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
- ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType))
-
-
- case 18 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
- ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
- ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
- ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
- ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
- ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
- ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
- ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
- ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
- ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
- ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType),
- ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType))
-
-
- case 19 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
- ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
- ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
- ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
- ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
- ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
- ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
- ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
- ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
- ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
- ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType),
- ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType),
- ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType))
-
-
- case 20 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
- ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
- ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
- ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
- ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
- ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
- ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
- ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
- ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
- ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
- ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType),
- ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType),
- ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType),
- ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType))
-
-
- case 21 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
- ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
- ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
- ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
- ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
- ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
- ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
- ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
- ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
- ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
- ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType),
- ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType),
- ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType),
- ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType),
- ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType))
-
-
- case 22 =>
- function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
- ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
- ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
- ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
- ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
- ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
- ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
- ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
- ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
- ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
- ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
- ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
- ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
- ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
- ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
- ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
- ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType),
- ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType),
- ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType),
- ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType),
- ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType),
- ScalaReflection.convertToScala(children(21).eval(input), children(21).dataType))
-
- }
- // scalastyle:on
-
- ScalaReflection.convertToCatalyst(result, dataType)
+
+ val f = children.size match {
+ case 0 =>
+ val func = function.asInstanceOf[() => Any]
+ (input: Row) => {
+ func()
+ }
+
+ case 1 =>
+ val func = function.asInstanceOf[(Any) => Any]
+ val child0 = children(0)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType))
+ }
+
+ case 2 =>
+ val func = function.asInstanceOf[(Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType))
+ }
+
+ case 3 =>
+ val func = function.asInstanceOf[(Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType))
+ }
+
+ case 4 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType))
+ }
+
+ case 5 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType))
+ }
+
+ case 6 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType))
+ }
+
+ case 7 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType))
+ }
+
+ case 8 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ val child7 = children(7)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType),
+ ScalaReflection.convertToScala(child7.eval(input), child7.dataType))
+ }
+
+ case 9 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ val child7 = children(7)
+ val child8 = children(8)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType),
+ ScalaReflection.convertToScala(child7.eval(input), child7.dataType),
+ ScalaReflection.convertToScala(child8.eval(input), child8.dataType))
+ }
+
+ case 10 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ val child7 = children(7)
+ val child8 = children(8)
+ val child9 = children(9)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType),
+ ScalaReflection.convertToScala(child7.eval(input), child7.dataType),
+ ScalaReflection.convertToScala(child8.eval(input), child8.dataType),
+ ScalaReflection.convertToScala(child9.eval(input), child9.dataType))
+ }
+
+ case 11 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ val child7 = children(7)
+ val child8 = children(8)
+ val child9 = children(9)
+ val child10 = children(10)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType),
+ ScalaReflection.convertToScala(child7.eval(input), child7.dataType),
+ ScalaReflection.convertToScala(child8.eval(input), child8.dataType),
+ ScalaReflection.convertToScala(child9.eval(input), child9.dataType),
+ ScalaReflection.convertToScala(child10.eval(input), child10.dataType))
+ }
+
+ case 12 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ val child7 = children(7)
+ val child8 = children(8)
+ val child9 = children(9)
+ val child10 = children(10)
+ val child11 = children(11)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType),
+ ScalaReflection.convertToScala(child7.eval(input), child7.dataType),
+ ScalaReflection.convertToScala(child8.eval(input), child8.dataType),
+ ScalaReflection.convertToScala(child9.eval(input), child9.dataType),
+ ScalaReflection.convertToScala(child10.eval(input), child10.dataType),
+ ScalaReflection.convertToScala(child11.eval(input), child11.dataType))
+ }
+
+ case 13 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ val child7 = children(7)
+ val child8 = children(8)
+ val child9 = children(9)
+ val child10 = children(10)
+ val child11 = children(11)
+ val child12 = children(12)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType),
+ ScalaReflection.convertToScala(child7.eval(input), child7.dataType),
+ ScalaReflection.convertToScala(child8.eval(input), child8.dataType),
+ ScalaReflection.convertToScala(child9.eval(input), child9.dataType),
+ ScalaReflection.convertToScala(child10.eval(input), child10.dataType),
+ ScalaReflection.convertToScala(child11.eval(input), child11.dataType),
+ ScalaReflection.convertToScala(child12.eval(input), child12.dataType))
+ }
+
+ case 14 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ val child7 = children(7)
+ val child8 = children(8)
+ val child9 = children(9)
+ val child10 = children(10)
+ val child11 = children(11)
+ val child12 = children(12)
+ val child13 = children(13)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType),
+ ScalaReflection.convertToScala(child7.eval(input), child7.dataType),
+ ScalaReflection.convertToScala(child8.eval(input), child8.dataType),
+ ScalaReflection.convertToScala(child9.eval(input), child9.dataType),
+ ScalaReflection.convertToScala(child10.eval(input), child10.dataType),
+ ScalaReflection.convertToScala(child11.eval(input), child11.dataType),
+ ScalaReflection.convertToScala(child12.eval(input), child12.dataType),
+ ScalaReflection.convertToScala(child13.eval(input), child13.dataType))
+ }
+
+ case 15 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ val child7 = children(7)
+ val child8 = children(8)
+ val child9 = children(9)
+ val child10 = children(10)
+ val child11 = children(11)
+ val child12 = children(12)
+ val child13 = children(13)
+ val child14 = children(14)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType),
+ ScalaReflection.convertToScala(child7.eval(input), child7.dataType),
+ ScalaReflection.convertToScala(child8.eval(input), child8.dataType),
+ ScalaReflection.convertToScala(child9.eval(input), child9.dataType),
+ ScalaReflection.convertToScala(child10.eval(input), child10.dataType),
+ ScalaReflection.convertToScala(child11.eval(input), child11.dataType),
+ ScalaReflection.convertToScala(child12.eval(input), child12.dataType),
+ ScalaReflection.convertToScala(child13.eval(input), child13.dataType),
+ ScalaReflection.convertToScala(child14.eval(input), child14.dataType))
+ }
+
+ case 16 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ val child7 = children(7)
+ val child8 = children(8)
+ val child9 = children(9)
+ val child10 = children(10)
+ val child11 = children(11)
+ val child12 = children(12)
+ val child13 = children(13)
+ val child14 = children(14)
+ val child15 = children(15)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType),
+ ScalaReflection.convertToScala(child7.eval(input), child7.dataType),
+ ScalaReflection.convertToScala(child8.eval(input), child8.dataType),
+ ScalaReflection.convertToScala(child9.eval(input), child9.dataType),
+ ScalaReflection.convertToScala(child10.eval(input), child10.dataType),
+ ScalaReflection.convertToScala(child11.eval(input), child11.dataType),
+ ScalaReflection.convertToScala(child12.eval(input), child12.dataType),
+ ScalaReflection.convertToScala(child13.eval(input), child13.dataType),
+ ScalaReflection.convertToScala(child14.eval(input), child14.dataType),
+ ScalaReflection.convertToScala(child15.eval(input), child15.dataType))
+ }
+
+ case 17 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ val child7 = children(7)
+ val child8 = children(8)
+ val child9 = children(9)
+ val child10 = children(10)
+ val child11 = children(11)
+ val child12 = children(12)
+ val child13 = children(13)
+ val child14 = children(14)
+ val child15 = children(15)
+ val child16 = children(16)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType),
+ ScalaReflection.convertToScala(child7.eval(input), child7.dataType),
+ ScalaReflection.convertToScala(child8.eval(input), child8.dataType),
+ ScalaReflection.convertToScala(child9.eval(input), child9.dataType),
+ ScalaReflection.convertToScala(child10.eval(input), child10.dataType),
+ ScalaReflection.convertToScala(child11.eval(input), child11.dataType),
+ ScalaReflection.convertToScala(child12.eval(input), child12.dataType),
+ ScalaReflection.convertToScala(child13.eval(input), child13.dataType),
+ ScalaReflection.convertToScala(child14.eval(input), child14.dataType),
+ ScalaReflection.convertToScala(child15.eval(input), child15.dataType),
+ ScalaReflection.convertToScala(child16.eval(input), child16.dataType))
+ }
+
+ case 18 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ val child7 = children(7)
+ val child8 = children(8)
+ val child9 = children(9)
+ val child10 = children(10)
+ val child11 = children(11)
+ val child12 = children(12)
+ val child13 = children(13)
+ val child14 = children(14)
+ val child15 = children(15)
+ val child16 = children(16)
+ val child17 = children(17)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType),
+ ScalaReflection.convertToScala(child7.eval(input), child7.dataType),
+ ScalaReflection.convertToScala(child8.eval(input), child8.dataType),
+ ScalaReflection.convertToScala(child9.eval(input), child9.dataType),
+ ScalaReflection.convertToScala(child10.eval(input), child10.dataType),
+ ScalaReflection.convertToScala(child11.eval(input), child11.dataType),
+ ScalaReflection.convertToScala(child12.eval(input), child12.dataType),
+ ScalaReflection.convertToScala(child13.eval(input), child13.dataType),
+ ScalaReflection.convertToScala(child14.eval(input), child14.dataType),
+ ScalaReflection.convertToScala(child15.eval(input), child15.dataType),
+ ScalaReflection.convertToScala(child16.eval(input), child16.dataType),
+ ScalaReflection.convertToScala(child17.eval(input), child17.dataType))
+ }
+
+ case 19 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ val child7 = children(7)
+ val child8 = children(8)
+ val child9 = children(9)
+ val child10 = children(10)
+ val child11 = children(11)
+ val child12 = children(12)
+ val child13 = children(13)
+ val child14 = children(14)
+ val child15 = children(15)
+ val child16 = children(16)
+ val child17 = children(17)
+ val child18 = children(18)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType),
+ ScalaReflection.convertToScala(child7.eval(input), child7.dataType),
+ ScalaReflection.convertToScala(child8.eval(input), child8.dataType),
+ ScalaReflection.convertToScala(child9.eval(input), child9.dataType),
+ ScalaReflection.convertToScala(child10.eval(input), child10.dataType),
+ ScalaReflection.convertToScala(child11.eval(input), child11.dataType),
+ ScalaReflection.convertToScala(child12.eval(input), child12.dataType),
+ ScalaReflection.convertToScala(child13.eval(input), child13.dataType),
+ ScalaReflection.convertToScala(child14.eval(input), child14.dataType),
+ ScalaReflection.convertToScala(child15.eval(input), child15.dataType),
+ ScalaReflection.convertToScala(child16.eval(input), child16.dataType),
+ ScalaReflection.convertToScala(child17.eval(input), child17.dataType),
+ ScalaReflection.convertToScala(child18.eval(input), child18.dataType))
+ }
+
+ case 20 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ val child7 = children(7)
+ val child8 = children(8)
+ val child9 = children(9)
+ val child10 = children(10)
+ val child11 = children(11)
+ val child12 = children(12)
+ val child13 = children(13)
+ val child14 = children(14)
+ val child15 = children(15)
+ val child16 = children(16)
+ val child17 = children(17)
+ val child18 = children(18)
+ val child19 = children(19)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType),
+ ScalaReflection.convertToScala(child7.eval(input), child7.dataType),
+ ScalaReflection.convertToScala(child8.eval(input), child8.dataType),
+ ScalaReflection.convertToScala(child9.eval(input), child9.dataType),
+ ScalaReflection.convertToScala(child10.eval(input), child10.dataType),
+ ScalaReflection.convertToScala(child11.eval(input), child11.dataType),
+ ScalaReflection.convertToScala(child12.eval(input), child12.dataType),
+ ScalaReflection.convertToScala(child13.eval(input), child13.dataType),
+ ScalaReflection.convertToScala(child14.eval(input), child14.dataType),
+ ScalaReflection.convertToScala(child15.eval(input), child15.dataType),
+ ScalaReflection.convertToScala(child16.eval(input), child16.dataType),
+ ScalaReflection.convertToScala(child17.eval(input), child17.dataType),
+ ScalaReflection.convertToScala(child18.eval(input), child18.dataType),
+ ScalaReflection.convertToScala(child19.eval(input), child19.dataType))
+ }
+
+ case 21 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ val child7 = children(7)
+ val child8 = children(8)
+ val child9 = children(9)
+ val child10 = children(10)
+ val child11 = children(11)
+ val child12 = children(12)
+ val child13 = children(13)
+ val child14 = children(14)
+ val child15 = children(15)
+ val child16 = children(16)
+ val child17 = children(17)
+ val child18 = children(18)
+ val child19 = children(19)
+ val child20 = children(20)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType),
+ ScalaReflection.convertToScala(child7.eval(input), child7.dataType),
+ ScalaReflection.convertToScala(child8.eval(input), child8.dataType),
+ ScalaReflection.convertToScala(child9.eval(input), child9.dataType),
+ ScalaReflection.convertToScala(child10.eval(input), child10.dataType),
+ ScalaReflection.convertToScala(child11.eval(input), child11.dataType),
+ ScalaReflection.convertToScala(child12.eval(input), child12.dataType),
+ ScalaReflection.convertToScala(child13.eval(input), child13.dataType),
+ ScalaReflection.convertToScala(child14.eval(input), child14.dataType),
+ ScalaReflection.convertToScala(child15.eval(input), child15.dataType),
+ ScalaReflection.convertToScala(child16.eval(input), child16.dataType),
+ ScalaReflection.convertToScala(child17.eval(input), child17.dataType),
+ ScalaReflection.convertToScala(child18.eval(input), child18.dataType),
+ ScalaReflection.convertToScala(child19.eval(input), child19.dataType),
+ ScalaReflection.convertToScala(child20.eval(input), child20.dataType))
+ }
+
+ case 22 =>
+ val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]
+ val child0 = children(0)
+ val child1 = children(1)
+ val child2 = children(2)
+ val child3 = children(3)
+ val child4 = children(4)
+ val child5 = children(5)
+ val child6 = children(6)
+ val child7 = children(7)
+ val child8 = children(8)
+ val child9 = children(9)
+ val child10 = children(10)
+ val child11 = children(11)
+ val child12 = children(12)
+ val child13 = children(13)
+ val child14 = children(14)
+ val child15 = children(15)
+ val child16 = children(16)
+ val child17 = children(17)
+ val child18 = children(18)
+ val child19 = children(19)
+ val child20 = children(20)
+ val child21 = children(21)
+ (input: Row) => {
+ func(
+ ScalaReflection.convertToScala(child0.eval(input), child0.dataType),
+ ScalaReflection.convertToScala(child1.eval(input), child1.dataType),
+ ScalaReflection.convertToScala(child2.eval(input), child2.dataType),
+ ScalaReflection.convertToScala(child3.eval(input), child3.dataType),
+ ScalaReflection.convertToScala(child4.eval(input), child4.dataType),
+ ScalaReflection.convertToScala(child5.eval(input), child5.dataType),
+ ScalaReflection.convertToScala(child6.eval(input), child6.dataType),
+ ScalaReflection.convertToScala(child7.eval(input), child7.dataType),
+ ScalaReflection.convertToScala(child8.eval(input), child8.dataType),
+ ScalaReflection.convertToScala(child9.eval(input), child9.dataType),
+ ScalaReflection.convertToScala(child10.eval(input), child10.dataType),
+ ScalaReflection.convertToScala(child11.eval(input), child11.dataType),
+ ScalaReflection.convertToScala(child12.eval(input), child12.dataType),
+ ScalaReflection.convertToScala(child13.eval(input), child13.dataType),
+ ScalaReflection.convertToScala(child14.eval(input), child14.dataType),
+ ScalaReflection.convertToScala(child15.eval(input), child15.dataType),
+ ScalaReflection.convertToScala(child16.eval(input), child16.dataType),
+ ScalaReflection.convertToScala(child17.eval(input), child17.dataType),
+ ScalaReflection.convertToScala(child18.eval(input), child18.dataType),
+ ScalaReflection.convertToScala(child19.eval(input), child19.dataType),
+ ScalaReflection.convertToScala(child20.eval(input), child20.dataType),
+ ScalaReflection.convertToScala(child21.eval(input), child21.dataType))
+ }
}
+
+ // scalastyle:on
+
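+ // Convert the function's Scala return value back to Catalyst's internal representation.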
+ override def eval(input: Row): Any = ScalaReflection.convertToCatalyst(f(input), dataType)
+
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index d00b2ac09745c..83074eb1e6310 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.types.DataType
abstract sealed class SortDirection
case object Ascending extends SortDirection
@@ -31,12 +32,12 @@ case object Descending extends SortDirection
case class SortOrder(child: Expression, direction: SortDirection) extends Expression
with trees.UnaryNode[Expression] {
- override def dataType = child.dataType
- override def nullable = child.nullable
+ override def dataType: DataType = child.dataType
+ override def nullable: Boolean = child.nullable
// SortOrder itself is never evaluated.
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
- override def toString = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
+ override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 21d714c9a8c3b..47b6f358ed1b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -62,126 +62,126 @@ abstract class MutableValue extends Serializable {
var isNull: Boolean = true
def boxed: Any
def update(v: Any)
- def copy(): this.type
+ def copy(): MutableValue
}
final class MutableInt extends MutableValue {
var value: Int = 0
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = {
isNull = false
- v.asInstanceOf[Int]
+ value = v.asInstanceOf[Int]
}
- def copy() = {
+ override def copy(): MutableInt = {
val newCopy = new MutableInt
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableInt]
}
}
final class MutableFloat extends MutableValue {
var value: Float = 0
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = {
isNull = false
- v.asInstanceOf[Float]
+ value = v.asInstanceOf[Float]
}
- def copy() = {
+ override def copy(): MutableFloat = {
val newCopy = new MutableFloat
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableFloat]
}
}
final class MutableBoolean extends MutableValue {
var value: Boolean = false
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = {
isNull = false
- v.asInstanceOf[Boolean]
+ value = v.asInstanceOf[Boolean]
}
- def copy() = {
+ override def copy(): MutableBoolean = {
val newCopy = new MutableBoolean
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableBoolean]
}
}
final class MutableDouble extends MutableValue {
var value: Double = 0
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = {
isNull = false
- v.asInstanceOf[Double]
+ value = v.asInstanceOf[Double]
}
- def copy() = {
+ override def copy(): MutableDouble = {
val newCopy = new MutableDouble
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableDouble]
}
}
final class MutableShort extends MutableValue {
var value: Short = 0
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = value = {
isNull = false
v.asInstanceOf[Short]
}
- def copy() = {
+ override def copy(): MutableShort = {
val newCopy = new MutableShort
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableShort]
}
}
final class MutableLong extends MutableValue {
var value: Long = 0
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = value = {
isNull = false
v.asInstanceOf[Long]
}
- def copy() = {
+ override def copy(): MutableLong = {
val newCopy = new MutableLong
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableLong]
}
}
final class MutableByte extends MutableValue {
var value: Byte = 0
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = value = {
isNull = false
v.asInstanceOf[Byte]
}
- def copy() = {
+ override def copy(): MutableByte = {
val newCopy = new MutableByte
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableByte]
}
}
final class MutableAny extends MutableValue {
var value: Any = _
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = {
isNull = false
- v.asInstanceOf[Any]
+ value = v.asInstanceOf[Any]
}
- def copy() = {
+ override def copy(): MutableAny = {
val newCopy = new MutableAny
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableAny]
}
}
@@ -234,9 +234,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
if (value == null) setNullAt(ordinal) else values(ordinal).update(value)
}
- override def setString(ordinal: Int, value: String) = update(ordinal, value)
+ override def setString(ordinal: Int, value: String): Unit = update(ordinal, value)
- override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
+ override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String]
override def setInt(ordinal: Int, value: Int): Unit = {
val currentValue = values(ordinal).asInstanceOf[MutableInt]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 5297d1e31246c..30da4faa3f1c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -79,27 +79,29 @@ abstract class AggregateFunction
/** Base should return the generic aggregate expression that this function is computing */
val base: AggregateExpression
- override def nullable = base.nullable
- override def dataType = base.dataType
+ override def nullable: Boolean = base.nullable
+ override def dataType: DataType = base.dataType
def update(input: Row): Unit
// Do we really need this?
- override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
+ override def newInstance(): AggregateFunction = {
+ makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
+ }
}
case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = true
- override def dataType = child.dataType
- override def toString = s"MIN($child)"
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"MIN($child)"
override def asPartial: SplitEvaluation = {
val partialMin = Alias(Min(child), "PartialMin")()
SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil)
}
- override def newInstance() = new MinFunction(child, this)
+ override def newInstance(): MinFunction = new MinFunction(child, this)
}
case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -121,16 +123,16 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr
case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = true
- override def dataType = child.dataType
- override def toString = s"MAX($child)"
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"MAX($child)"
override def asPartial: SplitEvaluation = {
val partialMax = Alias(Max(child), "PartialMax")()
SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil)
}
- override def newInstance() = new MaxFunction(child, this)
+ override def newInstance(): MaxFunction = new MaxFunction(child, this)
}
case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -152,29 +154,29 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = false
- override def dataType = LongType
- override def toString = s"COUNT($child)"
+ override def nullable: Boolean = false
+ override def dataType: LongType.type = LongType
+ override def toString: String = s"COUNT($child)"
override def asPartial: SplitEvaluation = {
val partialCount = Alias(Count(child), "PartialCount")()
SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil)
}
- override def newInstance() = new CountFunction(child, this)
+ override def newInstance(): CountFunction = new CountFunction(child, this)
}
case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate {
def this() = this(null)
- override def children = expressions
+ override def children: Seq[Expression] = expressions
- override def nullable = false
- override def dataType = LongType
- override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
- override def newInstance() = new CountDistinctFunction(expressions, this)
+ override def nullable: Boolean = false
+ override def dataType: DataType = LongType
+ override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})"
+ override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this)
- override def asPartial = {
+ override def asPartial: SplitEvaluation = {
val partialSet = Alias(CollectHashSet(expressions), "partialSets")()
SplitEvaluation(
CombineSetsAndCount(partialSet.toAttribute),
@@ -185,11 +187,11 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate
case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression {
def this() = this(null)
- override def children = expressions
- override def nullable = false
- override def dataType = ArrayType(expressions.head.dataType)
- override def toString = s"AddToHashSet(${expressions.mkString(",")})"
- override def newInstance() = new CollectHashSetFunction(expressions, this)
+ override def children: Seq[Expression] = expressions
+ override def nullable: Boolean = false
+ override def dataType: ArrayType = ArrayType(expressions.head.dataType)
+ override def toString: String = s"AddToHashSet(${expressions.mkString(",")})"
+ override def newInstance(): CollectHashSetFunction = new CollectHashSetFunction(expressions, this)
}
case class CollectHashSetFunction(
@@ -219,11 +221,13 @@ case class CollectHashSetFunction(
case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression {
def this() = this(null)
- override def children = inputSet :: Nil
- override def nullable = false
- override def dataType = LongType
- override def toString = s"CombineAndCount($inputSet)"
- override def newInstance() = new CombineSetsAndCountFunction(inputSet, this)
+ override def children: Seq[Expression] = inputSet :: Nil
+ override def nullable: Boolean = false
+ override def dataType: DataType = LongType
+ override def toString: String = s"CombineAndCount($inputSet)"
+ override def newInstance(): CombineSetsAndCountFunction = {
+ new CombineSetsAndCountFunction(inputSet, this)
+ }
}
case class CombineSetsAndCountFunction(
@@ -249,27 +253,31 @@ case class CombineSetsAndCountFunction(
case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
- override def nullable = false
- override def dataType = child.dataType
- override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
- override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
+ override def nullable: Boolean = false
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
+ override def newInstance(): ApproxCountDistinctPartitionFunction = {
+ new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
+ }
}
case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
- override def nullable = false
- override def dataType = LongType
- override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
- override def newInstance() = new ApproxCountDistinctMergeFunction(child, this, relativeSD)
+ override def nullable: Boolean = false
+ override def dataType: LongType.type = LongType
+ override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
+ override def newInstance(): ApproxCountDistinctMergeFunction = {
+ new ApproxCountDistinctMergeFunction(child, this, relativeSD)
+ }
}
case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = false
- override def dataType = LongType
- override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+ override def nullable: Boolean = false
+ override def dataType: LongType.type = LongType
+ override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
override def asPartial: SplitEvaluation = {
val partialCount =
@@ -280,14 +288,14 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
partialCount :: Nil)
}
- override def newInstance() = new CountDistinctFunction(child :: Nil, this)
+ override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this)
}
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = true
+ override def nullable: Boolean = true
- override def dataType = child.dataType match {
+ override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive
case DecimalType.Unlimited =>
@@ -296,7 +304,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
DoubleType
}
- override def toString = s"AVG($child)"
+ override def toString: String = s"AVG($child)"
override def asPartial: SplitEvaluation = {
child.dataType match {
@@ -323,14 +331,14 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
}
}
- override def newInstance() = new AverageFunction(child, this)
+ override def newInstance(): AverageFunction = new AverageFunction(child, this)
}
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = true
+ override def nullable: Boolean = true
- override def dataType = child.dataType match {
+ override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
case DecimalType.Unlimited =>
@@ -339,7 +347,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
child.dataType
}
- override def toString = s"SUM($child)"
+ override def toString: String = s"SUM($child)"
override def asPartial: SplitEvaluation = {
child.dataType match {
@@ -357,7 +365,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
}
}
- override def newInstance() = new SumFunction(child, this)
+ override def newInstance(): SumFunction = new SumFunction(child, this)
}
/**
@@ -377,19 +385,19 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
case class CombineSum(child: Expression) extends AggregateExpression {
def this() = this(null)
- override def children = child :: Nil
- override def nullable = true
- override def dataType = child.dataType
- override def toString = s"CombineSum($child)"
- override def newInstance() = new CombineSumFunction(child, this)
+ override def children: Seq[Expression] = child :: Nil
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"CombineSum($child)"
+ override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
}
case class SumDistinct(child: Expression)
extends PartialAggregate with trees.UnaryNode[Expression] {
def this() = this(null)
- override def nullable = true
- override def dataType = child.dataType match {
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
case DecimalType.Unlimited =>
@@ -397,10 +405,10 @@ case class SumDistinct(child: Expression)
case _ =>
child.dataType
}
- override def toString = s"SUM(DISTINCT ${child})"
- override def newInstance() = new SumDistinctFunction(child, this)
+ override def toString: String = s"SUM(DISTINCT $child)"
+ override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this)
- override def asPartial = {
+ override def asPartial: SplitEvaluation = {
val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")()
SplitEvaluation(
CombineSetsAndSum(partialSet.toAttribute, this),
@@ -411,11 +419,13 @@ case class SumDistinct(child: Expression)
case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression {
def this() = this(null, null)
- override def children = inputSet :: Nil
- override def nullable = true
- override def dataType = base.dataType
- override def toString = s"CombineAndSum($inputSet)"
- override def newInstance() = new CombineSetsAndSumFunction(inputSet, this)
+ override def children: Seq[Expression] = inputSet :: Nil
+ override def nullable: Boolean = true
+ override def dataType: DataType = base.dataType
+ override def toString: String = s"CombineAndSum($inputSet)"
+ override def newInstance(): CombineSetsAndSumFunction = {
+ new CombineSetsAndSumFunction(inputSet, this)
+ }
}
case class CombineSetsAndSumFunction(
@@ -449,9 +459,9 @@ case class CombineSetsAndSumFunction(
}
case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = true
- override def dataType = child.dataType
- override def toString = s"FIRST($child)"
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"FIRST($child)"
override def asPartial: SplitEvaluation = {
val partialFirst = Alias(First(child), "PartialFirst")()
@@ -459,14 +469,14 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod
First(partialFirst.toAttribute),
partialFirst :: Nil)
}
- override def newInstance() = new FirstFunction(child, this)
+ override def newInstance(): FirstFunction = new FirstFunction(child, this)
}
case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
- override def nullable = true
- override def dataType = child.dataType
- override def toString = s"LAST($child)"
+ override def references: AttributeSet = child.references
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"LAST($child)"
override def asPartial: SplitEvaluation = {
val partialLast = Alias(Last(child), "PartialLast")()
@@ -474,7 +484,7 @@ case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode
Last(partialLast.toAttribute),
partialLast :: Nil)
}
- override def newInstance() = new LastFunction(child, this)
+ override def newInstance(): LastFunction = new LastFunction(child, this)
}
case class AverageFunction(expr: Expression, base: AggregateExpression)
@@ -713,6 +723,7 @@ case class LastFunction(expr: Expression, base: AggregateExpression) extends Agg
result = input
}
- override def eval(input: Row): Any = if (result != null) expr.eval(result.asInstanceOf[Row])
- else null
+ override def eval(input: Row): Any = {
+ if (result != null) expr.eval(result.asInstanceOf[Row]) else null
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 00b0d3c683fe2..1f6526ef66c56 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -24,10 +24,10 @@ import org.apache.spark.sql.types._
case class UnaryMinus(child: Expression) extends UnaryExpression {
type EvaluatedType = Any
- def dataType = child.dataType
- override def foldable = child.foldable
- def nullable = child.nullable
- override def toString = s"-$child"
+ override def dataType: DataType = child.dataType
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = child.nullable
+ override def toString: String = s"-$child"
lazy val numeric = dataType match {
case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
@@ -47,10 +47,10 @@ case class UnaryMinus(child: Expression) extends UnaryExpression {
case class Sqrt(child: Expression) extends UnaryExpression {
type EvaluatedType = Any
- def dataType = DoubleType
- override def foldable = child.foldable
- def nullable = true
- override def toString = s"SQRT($child)"
+ override def dataType: DataType = DoubleType
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = true
+ override def toString: String = s"SQRT($child)"
lazy val numeric = child.dataType match {
case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
@@ -74,14 +74,14 @@ abstract class BinaryArithmetic extends BinaryExpression {
type EvaluatedType = Any
- def nullable = left.nullable || right.nullable
+ def nullable: Boolean = left.nullable || right.nullable
override lazy val resolved =
left.resolved && right.resolved &&
left.dataType == right.dataType &&
!DecimalType.isFixed(left.dataType)
- def dataType = {
+ def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this,
s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
@@ -108,7 +108,7 @@ abstract class BinaryArithmetic extends BinaryExpression {
}
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "+"
+ override def symbol: String = "+"
lazy val numeric = dataType match {
case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
@@ -131,7 +131,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
}
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "-"
+ override def symbol: String = "-"
lazy val numeric = dataType match {
case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
@@ -154,7 +154,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
}
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "*"
+ override def symbol: String = "*"
lazy val numeric = dataType match {
case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
@@ -177,9 +177,9 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
}
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "/"
+ override def symbol: String = "/"
- override def nullable = true
+ override def nullable: Boolean = true
lazy val div: (Any, Any) => Any = dataType match {
case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
@@ -203,9 +203,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "%"
+ override def symbol: String = "%"
- override def nullable = true
+ override def nullable: Boolean = true
lazy val integral = dataType match {
case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
@@ -232,7 +232,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
* A function that calculates bitwise and(&) of two numbers.
*/
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "&"
+ override def symbol: String = "&"
lazy val and: (Any, Any) => Any = dataType match {
case ByteType =>
@@ -253,7 +253,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
* A function that calculates bitwise or(|) of two numbers.
*/
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "|"
+ override def symbol: String = "|"
lazy val or: (Any, Any) => Any = dataType match {
case ByteType =>
@@ -274,7 +274,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
* A function that calculates bitwise xor(^) of two numbers.
*/
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "^"
+ override def symbol: String = "^"
lazy val xor: (Any, Any) => Any = dataType match {
case ByteType =>
@@ -297,10 +297,10 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
case class BitwiseNot(child: Expression) extends UnaryExpression {
type EvaluatedType = Any
- def dataType = child.dataType
- override def foldable = child.foldable
- def nullable = child.nullable
- override def toString = s"~$child"
+ override def dataType: DataType = child.dataType
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = child.nullable
+ override def toString: String = s"~$child"
lazy val not: (Any) => Any = dataType match {
case ByteType =>
@@ -327,17 +327,17 @@ case class BitwiseNot(child: Expression) extends UnaryExpression {
case class MaxOf(left: Expression, right: Expression) extends Expression {
type EvaluatedType = Any
- override def foldable = left.foldable && right.foldable
+ override def foldable: Boolean = left.foldable && right.foldable
- override def nullable = left.nullable && right.nullable
+ override def nullable: Boolean = left.nullable && right.nullable
- override def children = left :: right :: Nil
+ override def children: Seq[Expression] = left :: right :: Nil
override lazy val resolved =
left.resolved && right.resolved &&
left.dataType == right.dataType
- override def dataType = {
+ override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this,
s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
@@ -366,7 +366,7 @@ case class MaxOf(left: Expression, right: Expression) extends Expression {
}
}
- override def toString = s"MaxOf($left, $right)"
+ override def toString: String = s"MaxOf($left, $right)"
}
/**
@@ -375,10 +375,10 @@ case class MaxOf(left: Expression, right: Expression) extends Expression {
case class Abs(child: Expression) extends UnaryExpression {
type EvaluatedType = Any
- def dataType = child.dataType
- override def foldable = child.foldable
- def nullable = child.nullable
- override def toString = s"Abs($child)"
+ override def dataType: DataType = child.dataType
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = child.nullable
+ override def toString: String = s"Abs($child)"
lazy val numeric = dataType match {
case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index e48b8cde20eda..d1abf3c0b64a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -91,7 +91,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val startTime = System.nanoTime()
val result = create(in)
val endTime = System.nanoTime()
- def timeMs = (endTime - startTime).toDouble / 1000000
+ def timeMs: Double = (endTime - startTime).toDouble / 1000000
logInfo(s"Code generated expression $in in $timeMs ms")
result
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index 68051a2a2007e..3fd78db297462 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -27,12 +27,12 @@ import org.apache.spark.sql.types._
case class GetItem(child: Expression, ordinal: Expression) extends Expression {
type EvaluatedType = Any
- val children = child :: ordinal :: Nil
+ val children: Seq[Expression] = child :: ordinal :: Nil
/** `Null` is returned for invalid ordinals. */
- override def nullable = true
- override def foldable = child.foldable && ordinal.foldable
+ override def nullable: Boolean = true
+ override def foldable: Boolean = child.foldable && ordinal.foldable
- def dataType = child.dataType match {
+ override def dataType: DataType = child.dataType match {
case ArrayType(dt, _) => dt
case MapType(_, vt, _) => vt
}
@@ -40,7 +40,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
childrenResolved &&
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
- override def toString = s"$child[$ordinal]"
+ override def toString: String = s"$child[$ordinal]"
override def eval(input: Row): Any = {
val value = child.eval(input)
@@ -75,8 +75,8 @@ trait GetField extends UnaryExpression {
self: Product =>
type EvaluatedType = Any
- override def foldable = child.foldable
- override def toString = s"$child.${field.name}"
+ override def foldable: Boolean = child.foldable
+ override def toString: String = s"$child.${field.name}"
def field: StructField
}
@@ -86,8 +86,8 @@ trait GetField extends UnaryExpression {
*/
case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField {
- def dataType = field.dataType
- override def nullable = child.nullable || field.nullable
+ override def dataType: DataType = field.dataType
+ override def nullable: Boolean = child.nullable || field.nullable
override def eval(input: Row): Any = {
val baseValue = child.eval(input).asInstanceOf[Row]
@@ -101,8 +101,8 @@ case class StructGetField(child: Expression, field: StructField, ordinal: Int) e
case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean)
extends GetField {
- def dataType = ArrayType(field.dataType, containsNull)
- override def nullable = child.nullable
+ override def dataType: DataType = ArrayType(field.dataType, containsNull)
+ override def nullable: Boolean = child.nullable
override def eval(input: Row): Any = {
val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
@@ -120,7 +120,7 @@ case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, co
case class CreateArray(children: Seq[Expression]) extends Expression {
override type EvaluatedType = Any
- override def foldable = !children.exists(!_.foldable)
+ override def foldable: Boolean = !children.exists(!_.foldable)
lazy val childTypes = children.map(_.dataType).distinct
@@ -140,5 +140,5 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
children.map(_.eval(input))
}
- override def toString = s"Array(${children.mkString(",")})"
+ override def toString: String = s"Array(${children.mkString(",")})"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
index 83d8c1d42bca4..adb94df7d1c7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -24,9 +24,9 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
override type EvaluatedType = Any
override def dataType: DataType = LongType
- override def foldable = child.foldable
- def nullable = child.nullable
- override def toString = s"UnscaledValue($child)"
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = child.nullable
+ override def toString: String = s"UnscaledValue($child)"
override def eval(input: Row): Any = {
val childResult = child.eval(input)
@@ -43,9 +43,9 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
override type EvaluatedType = Decimal
override def dataType: DataType = DecimalType(precision, scale)
- override def foldable = child.foldable
- def nullable = child.nullable
- override def toString = s"MakeDecimal($child,$precision,$scale)"
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = child.nullable
+ override def toString: String = s"MakeDecimal($child,$precision,$scale)"
override def eval(input: Row): Decimal = {
val childResult = child.eval(input)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 0983d274def3f..860b72fad38b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -45,7 +45,7 @@ abstract class Generator extends Expression {
override lazy val dataType =
ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
- override def nullable = false
+ override def nullable: Boolean = false
/**
* Should be overridden by specific generators. Called only once for each instance to ensure
@@ -89,7 +89,7 @@ case class UserDefinedGenerator(
function(inputRow(input))
}
- override def toString = s"UserDefinedGenerator(${children.mkString(",")})"
+ override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})"
}
/**
@@ -130,5 +130,5 @@ case class Explode(attributeNames: Seq[String], child: Expression)
}
}
- override def toString() = s"explode($child)"
+ override def toString: String = s"explode($child)"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 9ff66563c8164..19f3fc9c2291a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -64,14 +64,13 @@ object IntegerLiteral {
case class Literal(value: Any, dataType: DataType) extends LeafExpression {
- override def foldable = true
- def nullable = value == null
+ override def foldable: Boolean = true
+ override def nullable: Boolean = value == null
-
- override def toString = if (value != null) value.toString else "null"
+ override def toString: String = if (value != null) value.toString else "null"
type EvaluatedType = Any
- override def eval(input: Row):Any = value
+ override def eval(input: Row): Any = value
}
// TODO: Specialize
@@ -79,9 +78,9 @@ case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean
extends LeafExpression {
type EvaluatedType = Any
- def update(expression: Expression, input: Row) = {
+ def update(expression: Expression, input: Row): Unit = {
value = expression.eval(input)
}
- override def eval(input: Row) = value
+ override def eval(input: Row): Any = value
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 17f7f9fe51376..bcbcbeb31c7b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.trees.LeafNode
import org.apache.spark.sql.types._
object NamedExpression {
private val curId = new java.util.concurrent.atomic.AtomicLong()
- def newExprId = ExprId(curId.getAndIncrement())
+ def newExprId: ExprId = ExprId(curId.getAndIncrement())
def unapply(expr: NamedExpression): Option[(String, DataType)] = Some(expr.name, expr.dataType)
}
@@ -41,6 +42,13 @@ abstract class NamedExpression extends Expression {
def name: String
def exprId: ExprId
+ /**
+ * Returns a dot-separated, fully qualified name for this attribute. Given that there can be
+ * multiple qualifiers, there may be other valid ways to refer to this attribute.
+ */
+ def qualifiedName: String = (qualifiers.headOption.toSeq :+ name).mkString(".")
+
/**
* All possible qualifiers for the expression.
*
@@ -72,13 +80,13 @@ abstract class NamedExpression extends Expression {
abstract class Attribute extends NamedExpression {
self: Product =>
- override def references = AttributeSet(this)
+ override def references: AttributeSet = AttributeSet(this)
def withNullability(newNullability: Boolean): Attribute
def withQualifiers(newQualifiers: Seq[String]): Attribute
def withName(newName: String): Attribute
- def toAttribute = this
+ def toAttribute: Attribute = this
def newInstance(): Attribute
}
@@ -95,25 +103,30 @@ abstract class Attribute extends NamedExpression {
* @param name the name to be associated with the result of computing [[child]].
* @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this
* alias. Auto-assigned if left blank.
+ * @param explicitMetadata Explicit metadata associated with this alias that overrides the child's.
*/
-case class Alias(child: Expression, name: String)
- (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
+case class Alias(child: Expression, name: String)(
+ val exprId: ExprId = NamedExpression.newExprId,
+ val qualifiers: Seq[String] = Nil,
+ val explicitMetadata: Option[Metadata] = None)
extends NamedExpression with trees.UnaryNode[Expression] {
override type EvaluatedType = Any
- override def eval(input: Row) = child.eval(input)
+ override def eval(input: Row): Any = child.eval(input)
- override def dataType = child.dataType
- override def nullable = child.nullable
+ override def dataType: DataType = child.dataType
+ override def nullable: Boolean = child.nullable
override def metadata: Metadata = {
- child match {
- case named: NamedExpression => named.metadata
- case _ => Metadata.empty
+ explicitMetadata.getOrElse {
+ child match {
+ case named: NamedExpression => named.metadata
+ case _ => Metadata.empty
+ }
}
}
- override def toAttribute = {
+ override def toAttribute: Attribute = {
if (resolved) {
AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers)
} else {
@@ -123,11 +136,14 @@ case class Alias(child: Expression, name: String)
override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix"
- override protected final def otherCopyArgs = exprId :: qualifiers :: Nil
+ override protected final def otherCopyArgs: Seq[AnyRef] = {
+ exprId :: qualifiers :: explicitMetadata :: Nil
+ }
override def equals(other: Any): Boolean = other match {
case a: Alias =>
- name == a.name && exprId == a.exprId && child == a.child && qualifiers == a.qualifiers
+ name == a.name && exprId == a.exprId && child == a.child && qualifiers == a.qualifiers &&
+ explicitMetadata == a.explicitMetadata
case _ => false
}
}
@@ -153,7 +169,7 @@ case class AttributeReference(
val exprId: ExprId = NamedExpression.newExprId,
val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] {
- override def equals(other: Any) = other match {
+ override def equals(other: Any): Boolean = other match {
case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType
case _ => false
}
@@ -167,7 +183,7 @@ case class AttributeReference(
h
}
- override def newInstance() =
+ override def newInstance(): AttributeReference =
AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers)
/**
@@ -192,7 +208,7 @@ case class AttributeReference(
/**
* Returns a copy of this [[AttributeReference]] with new qualifiers.
*/
- override def withQualifiers(newQualifiers: Seq[String]) = {
+ override def withQualifiers(newQualifiers: Seq[String]): AttributeReference = {
if (newQualifiers.toSet == qualifiers.toSet) {
this
} else {
@@ -214,20 +230,22 @@ case class AttributeReference(
case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] {
type EvaluatedType = Any
- override def toString = name
-
- override def withNullability(newNullability: Boolean): Attribute = ???
- override def newInstance(): Attribute = ???
- override def withQualifiers(newQualifiers: Seq[String]): Attribute = ???
- override def withName(newName: String): Attribute = ???
- override def qualifiers: Seq[String] = ???
- override def exprId: ExprId = ???
- override def eval(input: Row): EvaluatedType = ???
- override def nullable: Boolean = ???
+ override def toString: String = name
+
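+ // PrettyAttribute only carries a name for display; most other members are unsupported.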
+ override def withNullability(newNullability: Boolean): Attribute =
+ throw new UnsupportedOperationException
+ override def newInstance(): Attribute = throw new UnsupportedOperationException
+ override def withQualifiers(newQualifiers: Seq[String]): Attribute =
+ throw new UnsupportedOperationException
+ override def withName(newName: String): Attribute = throw new UnsupportedOperationException
+ override def qualifiers: Seq[String] = throw new UnsupportedOperationException
+ override def exprId: ExprId = throw new UnsupportedOperationException
+ override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException
+ override def nullable: Boolean = throw new UnsupportedOperationException
override def dataType: DataType = NullType
}
object VirtualColumn {
- val groupingIdName = "grouping__id"
- def newGroupingId = AttributeReference(groupingIdName, IntegerType, false)()
+ val groupingIdName: String = "grouping__id"
+ def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)()
}
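As a rough illustration of the new explicitMetadata parameter (a sketch against the catalyst API shown above, not part of this patch): when supplied, the alias reports that metadata instead of inheriting the child's.

    import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
    import org.apache.spark.sql.types.MetadataBuilder

    // Build some metadata to attach to the alias; a Literal child carries none of its own.
    val meta = new MetadataBuilder().putString("comment", "price in USD").build()
    // With explicitMetadata set, Alias.metadata (and hence toAttribute) returns `meta`
    // rather than falling back to the child's metadata.
    val aliased = Alias(Literal(42), "price")(explicitMetadata = Some(meta))
    assert(aliased.metadata == meta)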
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index 08b982bc671e7..d1f3d4f4ee9ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -19,22 +19,23 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.types.DataType
case class Coalesce(children: Seq[Expression]) extends Expression {
type EvaluatedType = Any
/** Coalesce is nullable if all of its children are nullable, or if it has no children. */
- def nullable = !children.exists(!_.nullable)
+ override def nullable: Boolean = !children.exists(!_.nullable)
// Coalesce is foldable if all children are foldable.
- override def foldable = !children.exists(!_.foldable)
+ override def foldable: Boolean = !children.exists(!_.foldable)
// Only resolved if all the children are of the same type.
override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1)
- override def toString = s"Coalesce(${children.mkString(",")})"
+ override def toString: String = s"Coalesce(${children.mkString(",")})"
- def dataType = if (resolved) {
+ def dataType: DataType = if (resolved) {
children.head.dataType
} else {
val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ")
@@ -54,20 +55,20 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
- override def foldable = child.foldable
- def nullable = false
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = false
override def eval(input: Row): Any = {
child.eval(input) == null
}
- override def toString = s"IS NULL $child"
+ override def toString: String = s"IS NULL $child"
}
case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
- override def foldable = child.foldable
- def nullable = false
- override def toString = s"IS NOT NULL $child"
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = false
+ override def toString: String = s"IS NOT NULL $child"
override def eval(input: Row): Any = {
child.eval(input) != null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 0024ef92c0452..7e47cb3fffe12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types.{BinaryType, BooleanType, NativeType}
+import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, NativeType}
object InterpretedPredicate {
def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
@@ -34,7 +34,7 @@ object InterpretedPredicate {
trait Predicate extends Expression {
self: Product =>
- def dataType = BooleanType
+ override def dataType: DataType = BooleanType
type EvaluatedType = Any
}
@@ -72,13 +72,13 @@ trait PredicateHelper {
abstract class BinaryPredicate extends BinaryExpression with Predicate {
self: Product =>
- def nullable = left.nullable || right.nullable
+ override def nullable: Boolean = left.nullable || right.nullable
}
case class Not(child: Expression) extends UnaryExpression with Predicate {
- override def foldable = child.foldable
- def nullable = child.nullable
- override def toString = s"NOT $child"
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = child.nullable
+ override def toString: String = s"NOT $child"
override def eval(input: Row): Any = {
child.eval(input) match {
@@ -92,10 +92,10 @@ case class Not(child: Expression) extends UnaryExpression with Predicate {
* Evaluates to `true` if `list` contains `value`.
*/
case class In(value: Expression, list: Seq[Expression]) extends Predicate {
- def children = value +: list
+ override def children: Seq[Expression] = value +: list
- def nullable = true // TODO: Figure out correct nullability semantics of IN.
- override def toString = s"$value IN ${list.mkString("(", ",", ")")}"
+ override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
+ override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"
override def eval(input: Row): Any = {
val evaluatedValue = value.eval(input)
@@ -110,10 +110,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
case class InSet(value: Expression, hset: Set[Any])
extends Predicate {
- def children = value :: Nil
+ override def children: Seq[Expression] = value :: Nil
- def nullable = true // TODO: Figure out correct nullability semantics of IN.
- override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}"
+ override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
+ override def toString: String = s"$value INSET ${hset.mkString("(", ",", ")")}"
override def eval(input: Row): Any = {
hset.contains(value.eval(input))
@@ -121,7 +121,7 @@ case class InSet(value: Expression, hset: Set[Any])
}
case class And(left: Expression, right: Expression) extends BinaryPredicate {
- def symbol = "&&"
+ override def symbol: String = "&&"
override def eval(input: Row): Any = {
val l = left.eval(input)
@@ -143,7 +143,7 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate {
}
case class Or(left: Expression, right: Expression) extends BinaryPredicate {
- def symbol = "||"
+ override def symbol: String = "||"
override def eval(input: Row): Any = {
val l = left.eval(input)
@@ -169,7 +169,8 @@ abstract class BinaryComparison extends BinaryPredicate {
}
case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
- def symbol = "="
+ override def symbol: String = "="
+
override def eval(input: Row): Any = {
val l = left.eval(input)
if (l == null) {
@@ -185,8 +186,10 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
}
case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
- def symbol = "<=>"
- override def nullable = false
+ override def symbol: String = "<=>"
+
+ override def nullable: Boolean = false
+
override def eval(input: Row): Any = {
val l = left.eval(input)
val r = right.eval(input)
@@ -201,9 +204,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
}
case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
- def symbol = "<"
+ override def symbol: String = "<"
- lazy val ordering = {
+ lazy val ordering: Ordering[Any] = {
if (left.dataType != right.dataType) {
throw new TreeNodeException(this,
s"Types do not match ${left.dataType} != ${right.dataType}")
@@ -216,7 +219,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
override def eval(input: Row): Any = {
val evalE1 = left.eval(input)
- if(evalE1 == null) {
+ if (evalE1 == null) {
null
} else {
val evalE2 = right.eval(input)
@@ -230,9 +233,9 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
}
case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
- def symbol = "<="
+ override def symbol: String = "<="
- lazy val ordering = {
+ lazy val ordering: Ordering[Any] = {
if (left.dataType != right.dataType) {
throw new TreeNodeException(this,
s"Types do not match ${left.dataType} != ${right.dataType}")
@@ -245,7 +248,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
override def eval(input: Row): Any = {
val evalE1 = left.eval(input)
- if(evalE1 == null) {
+ if (evalE1 == null) {
null
} else {
val evalE2 = right.eval(input)
@@ -259,9 +262,9 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
}
case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
- def symbol = ">"
+ override def symbol: String = ">"
- lazy val ordering = {
+ lazy val ordering: Ordering[Any] = {
if (left.dataType != right.dataType) {
throw new TreeNodeException(this,
s"Types do not match ${left.dataType} != ${right.dataType}")
@@ -288,9 +291,9 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
}
case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
- def symbol = ">="
+ override def symbol: String = ">="
- lazy val ordering = {
+ lazy val ordering: Ordering[Any] = {
if (left.dataType != right.dataType) {
throw new TreeNodeException(this,
s"Types do not match ${left.dataType} != ${right.dataType}")
@@ -303,7 +306,7 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar
override def eval(input: Row): Any = {
val evalE1 = left.eval(input)
- if(evalE1 == null) {
+ if (evalE1 == null) {
null
} else {
val evalE2 = right.eval(input)
@@ -317,13 +320,13 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar
}
case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
- extends Expression {
+ extends Expression {
- def children = predicate :: trueValue :: falseValue :: Nil
- override def nullable = trueValue.nullable || falseValue.nullable
+ override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil
+ override def nullable: Boolean = trueValue.nullable || falseValue.nullable
override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType
- def dataType = {
+ override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(
this,
@@ -342,7 +345,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
}
}
- override def toString = s"if ($predicate) $trueValue else $falseValue"
+ override def toString: String = s"if ($predicate) $trueValue else $falseValue"
}
// scalastyle:off
@@ -362,9 +365,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
// scalastyle:on
case class CaseWhen(branches: Seq[Expression]) extends Expression {
type EvaluatedType = Any
- def children = branches
- def dataType = {
+ override def children: Seq[Expression] = branches
+
+ override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
}
@@ -379,12 +383,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
@transient private[this] lazy val elseValue =
if (branches.length % 2 == 0) None else Option(branches.last)
- override def nullable = {
+ override def nullable: Boolean = {
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
}
- override lazy val resolved = {
+ override lazy val resolved: Boolean = {
if (!childrenResolved) {
false
} else {
@@ -415,7 +419,7 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
res
}
- override def toString = {
+ override def toString: String = {
"CASE" + branches.sliding(2, 2).map {
case Seq(cond, value) => s" WHEN $cond THEN $value"
case Seq(elseValue) => s" ELSE $elseValue"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index f03d6f71a9fae..8bba26bc4cf7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -44,8 +44,8 @@ trait MutableRow extends Row {
*/
object EmptyRow extends Row {
override def apply(i: Int): Any = throw new UnsupportedOperationException
- override def toSeq = Seq.empty
- override def length = 0
+ override def toSeq: Seq[Any] = Seq.empty
+ override def length: Int = 0
override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException
override def getInt(i: Int): Int = throw new UnsupportedOperationException
override def getLong(i: Int): Long = throw new UnsupportedOperationException
@@ -56,7 +56,7 @@ object EmptyRow extends Row {
override def getByte(i: Int): Byte = throw new UnsupportedOperationException
override def getString(i: Int): String = throw new UnsupportedOperationException
override def getAs[T](i: Int): T = throw new UnsupportedOperationException
- def copy() = this
+ override def copy(): Row = this
}
/**
@@ -70,13 +70,13 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
def this(size: Int) = this(new Array[Any](size))
- override def toSeq = values.toSeq
+ override def toSeq: Seq[Any] = values.toSeq
- override def length = values.length
+ override def length: Int = values.length
- override def apply(i: Int) = values(i)
+ override def apply(i: Int): Any = values(i)
- override def isNullAt(i: Int) = values(i) == null
+ override def isNullAt(i: Int): Boolean = values(i) == null
override def getInt(i: Int): Int = {
if (values(i) == null) sys.error("Failed to check null bit for primitive int value.")
@@ -167,7 +167,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
case _ => false
}
- def copy() = this
+ override def copy(): Row = this
}
class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
@@ -194,7 +194,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value }
- override def copy() = new GenericRow(values.clone())
+ override def copy(): Row = new GenericRow(values.clone())
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index 3a5bdca1f07c3..35faa00782e80 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -26,17 +26,17 @@ import org.apache.spark.util.collection.OpenHashSet
case class NewSet(elementType: DataType) extends LeafExpression {
type EvaluatedType = Any
- def nullable = false
+ override def nullable: Boolean = false
// We are currently only using these Expressions internally for aggregation. However, if we ever
// expose these to users we'll want to create a proper type instead of hijacking ArrayType.
- def dataType = ArrayType(elementType)
+ override def dataType: DataType = ArrayType(elementType)
- def eval(input: Row): Any = {
+ override def eval(input: Row): Any = {
new OpenHashSet[Any]()
}
- override def toString = s"new Set($dataType)"
+ override def toString: String = s"new Set($dataType)"
}
/**
@@ -46,12 +46,13 @@ case class NewSet(elementType: DataType) extends LeafExpression {
case class AddItemToSet(item: Expression, set: Expression) extends Expression {
type EvaluatedType = Any
- def children = item :: set :: Nil
+ override def children: Seq[Expression] = item :: set :: Nil
- def nullable = set.nullable
+ override def nullable: Boolean = set.nullable
- def dataType = set.dataType
- def eval(input: Row): Any = {
+ override def dataType: DataType = set.dataType
+
+ override def eval(input: Row): Any = {
val itemEval = item.eval(input)
val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]
@@ -67,7 +68,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
}
}
- override def toString = s"$set += $item"
+ override def toString: String = s"$set += $item"
}
/**
@@ -77,13 +78,13 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
type EvaluatedType = Any
- def nullable = left.nullable || right.nullable
+ override def nullable: Boolean = left.nullable || right.nullable
- def dataType = left.dataType
+ override def dataType: DataType = left.dataType
- def symbol = "++="
+ override def symbol: String = "++="
- def eval(input: Row): Any = {
+ override def eval(input: Row): Any = {
val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]]
if(leftEval != null) {
val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]]
@@ -109,16 +110,16 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
case class CountSet(child: Expression) extends UnaryExpression {
type EvaluatedType = Any
- def nullable = child.nullable
+ override def nullable: Boolean = child.nullable
- def dataType = LongType
+ override def dataType: DataType = LongType
- def eval(input: Row): Any = {
+ override def eval(input: Row): Any = {
val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]]
if (childEval != null) {
childEval.size.toLong
}
}
- override def toString = s"$child.count()"
+ override def toString: String = s"$child.count()"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index f85ee0a9bb6d8..3cdca4e9dd2d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -33,8 +33,8 @@ trait StringRegexExpression {
def escape(v: String): String
def matches(regex: Pattern, str: String): Boolean
- def nullable: Boolean = left.nullable || right.nullable
- def dataType: DataType = BooleanType
+ override def nullable: Boolean = left.nullable || right.nullable
+ override def dataType: DataType = BooleanType
// try to cache the pattern for Literal
private lazy val cache: Pattern = right match {
@@ -98,11 +98,11 @@ trait CaseConversionExpression {
case class Like(left: Expression, right: Expression)
extends BinaryExpression with StringRegexExpression {
- def symbol = "LIKE"
+ override def symbol: String = "LIKE"
// replace _ with .{1}, matching exactly one arbitrary character
// replace % with .*, matching zero or more arbitrary characters
- override def escape(v: String) =
+ override def escape(v: String): String =
if (!v.isEmpty) {
"(?s)" + (' ' +: v.init).zip(v).flatMap {
case (prev, '\\') => ""
@@ -129,7 +129,7 @@ case class Like(left: Expression, right: Expression)
case class RLike(left: Expression, right: Expression)
extends BinaryExpression with StringRegexExpression {
- def symbol = "RLIKE"
+ override def symbol: String = "RLIKE"
override def escape(v: String): String = v
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
}
@@ -141,7 +141,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
override def convert(v: String): String = v.toUpperCase()
- override def toString() = s"Upper($child)"
+ override def toString: String = s"Upper($child)"
}
/**
@@ -151,7 +151,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE
override def convert(v: String): String = v.toLowerCase()
- override def toString() = s"Lower($child)"
+ override def toString: String = s"Lower($child)"
}
/** A base trait for functions that compare two strings, returning a boolean. */
@@ -160,7 +160,7 @@ trait StringComparison {
type EvaluatedType = Any
- def nullable: Boolean = left.nullable || right.nullable
+ override def nullable: Boolean = left.nullable || right.nullable
override def dataType: DataType = BooleanType
def compare(l: String, r: String): Boolean
@@ -175,9 +175,9 @@ trait StringComparison {
}
}
- def symbol: String = nodeName
+ override def symbol: String = nodeName
- override def toString() = s"$nodeName($left, $right)"
+ override def toString: String = s"$nodeName($left, $right)"
}
/**
@@ -185,7 +185,7 @@ trait StringComparison {
*/
case class Contains(left: Expression, right: Expression)
extends BinaryExpression with StringComparison {
- override def compare(l: String, r: String) = l.contains(r)
+ override def compare(l: String, r: String): Boolean = l.contains(r)
}
/**
@@ -193,7 +193,7 @@ case class Contains(left: Expression, right: Expression)
*/
case class StartsWith(left: Expression, right: Expression)
extends BinaryExpression with StringComparison {
- def compare(l: String, r: String) = l.startsWith(r)
+ override def compare(l: String, r: String): Boolean = l.startsWith(r)
}
/**
@@ -201,7 +201,7 @@ case class StartsWith(left: Expression, right: Expression)
*/
case class EndsWith(left: Expression, right: Expression)
extends BinaryExpression with StringComparison {
- def compare(l: String, r: String) = l.endsWith(r)
+ override def compare(l: String, r: String): Boolean = l.endsWith(r)
}
/**
@@ -212,17 +212,17 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
type EvaluatedType = Any
- override def foldable = str.foldable && pos.foldable && len.foldable
+ override def foldable: Boolean = str.foldable && pos.foldable && len.foldable
- def nullable: Boolean = str.nullable || pos.nullable || len.nullable
- def dataType: DataType = {
+ override def nullable: Boolean = str.nullable || pos.nullable || len.nullable
+ override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved")
}
if (str.dataType == BinaryType) str.dataType else StringType
}
- override def children = str :: pos :: len :: Nil
+ override def children: Seq[Expression] = str :: pos :: len :: Nil
@inline
def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int)
@@ -267,7 +267,8 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
}
}
- override def toString = len match {
+ override def toString: String = len match {
+ // TODO: This is broken because max is not an integer value.
case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)"
case _ => s"SUBSTR($str, $pos, $len)"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 1a75fcf3545bd..c23d3b61887c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer
import scala.collection.immutable.HashSet
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.FullOuter
@@ -32,6 +33,9 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan]
object DefaultOptimizer extends Optimizer {
val batches =
+ // SubQueries are only needed for analysis and can be removed before execution.
+ Batch("Remove SubQueries", FixedPoint(100),
+ EliminateSubQueries) ::
Batch("Combine Limits", FixedPoint(100),
CombineLimits) ::
Batch("ConstantFolding", FixedPoint(100),
@@ -137,7 +141,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
/** Applies a projection only when the child is producing unnecessary attributes */
- def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences)
+ def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences)
Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))
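A minimal sketch (using the catalyst test DSL, with hypothetical column names) of what the new "Remove SubQueries" batch does: the Subquery wrapper left behind by analysis is stripped, so the batches that follow operate on the underlying plan directly.

    import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val relation = LocalRelation('a.int, 'b.int)
    // An aliased (sub-queried) projection, as analysis might leave it.
    val plan = relation.subquery('t).select('a)
    // EliminateSubQueries drops the Subquery node, leaving Project over LocalRelation.
    val cleaned = EliminateSubQueries(plan)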
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index b4c445b3badf1..9c8c643f7d17a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -91,16 +91,18 @@ object PhysicalOperation extends PredicateHelper {
(None, Nil, other, Map.empty)
}
- def collectAliases(fields: Seq[Expression]) = fields.collect {
+ def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect {
case a @ Alias(child, _) => a.toAttribute.asInstanceOf[Attribute] -> child
}.toMap
- def substitute(aliases: Map[Attribute, Expression])(expr: Expression) = expr.transform {
- case a @ Alias(ref: AttributeReference, name) =>
- aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a)
+ def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = {
+ expr.transform {
+ case a @ Alias(ref: AttributeReference, name) =>
+ aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a)
- case a: AttributeReference =>
- aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a)
+ case a: AttributeReference =>
+ aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a)
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 48191f31198f3..02f7c26a8ab6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -71,7 +71,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = {
var changed = false
- @inline def transformExpressionDown(e: Expression) = {
+ @inline def transformExpressionDown(e: Expression): Expression = {
val newE = e.transformDown(rule)
if (newE.fastEquals(e)) {
e
@@ -85,6 +85,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
case e: Expression => transformExpressionDown(e)
case Some(e: Expression) => Some(transformExpressionDown(e))
case m: Map[_,_] => m
+ case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map {
case e: Expression => transformExpressionDown(e)
case other => other
@@ -103,7 +104,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = {
var changed = false
- @inline def transformExpressionUp(e: Expression) = {
+ @inline def transformExpressionUp(e: Expression): Expression = {
val newE = e.transformUp(rule)
if (newE.fastEquals(e)) {
e
@@ -117,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
case e: Expression => transformExpressionUp(e)
case Some(e: Expression) => Some(transformExpressionUp(e))
case m: Map[_,_] => m
+ case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map {
case e: Expression => transformExpressionUp(e)
case other => other
@@ -163,5 +165,5 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
*/
protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else ""
- override def simpleString = statePrefix + super.simpleString
+ override def simpleString: String = statePrefix + super.simpleString
}
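Why the `case d: DataType` guard matters (an illustrative note, not part of the patch): StructType extends Seq[StructField], so without the guard it would fall into the Traversable branch and be rebuilt field by field during expression transformation.

    import org.apache.spark.sql.types._

    val struct = StructType(Seq(
      StructField("a", IntegerType),
      StructField("b", StringType)))
    // A StructType is itself a Traversable of its fields, which is what the new
    // DataType case short-circuits before the Traversable case can fire.
    assert(struct.isInstanceOf[Traversable[_]])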
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 8c4f09b58a4f2..b01a61d7bf8d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, Resolver}
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, Resolver}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
@@ -73,12 +73,16 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
* can do better should override this function.
*/
def sameResult(plan: LogicalPlan): Boolean = {
- plan.getClass == this.getClass &&
- plan.children.size == children.size && {
- logDebug(s"[${cleanArgs.mkString(", ")}] == [${plan.cleanArgs.mkString(", ")}]")
- cleanArgs == plan.cleanArgs
+ val cleanLeft = EliminateSubQueries(this)
+ val cleanRight = EliminateSubQueries(plan)
+
+ cleanLeft.getClass == cleanRight.getClass &&
+ cleanLeft.children.size == cleanRight.children.size && {
+ logDebug(
+ s"[${cleanRight.cleanArgs.mkString(", ")}] == [${cleanLeft.cleanArgs.mkString(", ")}]")
+ cleanRight.cleanArgs == cleanLeft.cleanArgs
} &&
- (plan.children, children).zipped.forall(_ sameResult _)
+ (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _)
}
/** Args that have cleaned such that differences in expression id should not affect equality */
@@ -208,8 +212,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// More than one match.
case ambiguousReferences =>
+ val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ")
throw new AnalysisException(
- s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
+ s"Reference '$name' is ambiguous, could be: $referenceNames.")
}
}
}
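A small sketch of the behavioural change to sameResult (catalyst DSL, hypothetical column name): plans that differ only by a subquery alias now compare as producing the same result.

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val table = LocalRelation('id.int)
    // Previously false (Subquery vs. LocalRelation); with EliminateSubQueries applied to
    // both sides inside sameResult, the alias no longer matters.
    assert(table.subquery('t).sameResult(table))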
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 384fe53a68362..4d9e41a2b5d85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
- def output = projectList.map(_.toAttribute)
+ override def output: Seq[Attribute] = projectList.map(_.toAttribute)
override lazy val resolved: Boolean = {
val containsAggregatesOrGenerators = projectList.exists ( _.collect {
@@ -66,19 +66,19 @@ case class Generate(
}
}
- override def output =
+ override def output: Seq[Attribute] =
if (join) child.output ++ generatorOutput else generatorOutput
}
case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
- override def output = child.output
+ override def output: Seq[Attribute] = child.output
}
case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
// TODO: These aren't really the same attributes as nullability etc might change.
- override def output = left.output
+ override def output: Seq[Attribute] = left.output
- override lazy val resolved =
+ override lazy val resolved: Boolean =
childrenResolved &&
!left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType }
@@ -94,7 +94,7 @@ case class Join(
joinType: JoinType,
condition: Option[Expression]) extends BinaryNode {
- override def output = {
+ override def output: Seq[Attribute] = {
joinType match {
case LeftSemi =>
left.output
@@ -109,7 +109,7 @@ case class Join(
}
}
- def selfJoinResolved = left.outputSet.intersect(right.outputSet).isEmpty
+ private def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
// Joins are only resolved if they don't introduce ambiguous expression ids.
override lazy val resolved: Boolean = {
@@ -118,7 +118,7 @@ case class Join(
}
case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
- def output = left.output
+ override def output: Seq[Attribute] = left.output
}
case class InsertIntoTable(
@@ -128,10 +128,10 @@ case class InsertIntoTable(
overwrite: Boolean)
extends LogicalPlan {
- override def children = child :: Nil
- override def output = child.output
+ override def children: Seq[LogicalPlan] = child :: Nil
+ override def output: Seq[Attribute] = child.output
- override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
+ override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall {
case (childAttr, tableAttr) =>
DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType)
}
@@ -143,14 +143,14 @@ case class CreateTableAsSelect[T](
child: LogicalPlan,
allowExisting: Boolean,
desc: Option[T] = None) extends UnaryNode {
- override def output = Seq.empty[Attribute]
- override lazy val resolved = databaseName != None && childrenResolved
+ override def output: Seq[Attribute] = Seq.empty[Attribute]
+ override lazy val resolved: Boolean = databaseName != None && childrenResolved
}
case class WriteToFile(
path: String,
child: LogicalPlan) extends UnaryNode {
- override def output = child.output
+ override def output: Seq[Attribute] = child.output
}
/**
@@ -163,7 +163,7 @@ case class Sort(
order: Seq[SortOrder],
global: Boolean,
child: LogicalPlan) extends UnaryNode {
- override def output = child.output
+ override def output: Seq[Attribute] = child.output
}
case class Aggregate(
@@ -172,7 +172,7 @@ case class Aggregate(
child: LogicalPlan)
extends UnaryNode {
- override def output = aggregateExpressions.map(_.toAttribute)
+ override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
}
/**
@@ -199,7 +199,7 @@ trait GroupingAnalytics extends UnaryNode {
def groupByExprs: Seq[Expression]
def aggregations: Seq[NamedExpression]
- override def output = aggregations.map(_.toAttribute)
+ override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
}
/**
@@ -264,7 +264,7 @@ case class Rollup(
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
- override def output = child.output
+ override def output: Seq[Attribute] = child.output
override lazy val statistics: Statistics = {
val limit = limitExpr.eval(null).asInstanceOf[Int]
@@ -274,21 +274,21 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
}
case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
- override def output = child.output.map(_.withQualifiers(alias :: Nil))
+ override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
}
case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
extends UnaryNode {
- override def output = child.output
+ override def output: Seq[Attribute] = child.output
}
case class Distinct(child: LogicalPlan) extends UnaryNode {
- override def output = child.output
+ override def output: Seq[Attribute] = child.output
}
case object NoRelation extends LeafNode {
- override def output = Nil
+ override def output: Seq[Attribute] = Nil
/**
* Computes [[Statistics]] for this plan. The default implementation assumes the output
@@ -301,5 +301,5 @@ case object NoRelation extends LeafNode {
}
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
- override def output = left.output
+ override def output: Seq[Attribute] = left.output
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
index 72b0c5c8e7a26..e737418d9c3bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder}
/**
* Performs a physical redistribution of the data. Used when the consumer of the query
@@ -26,14 +26,11 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder}
abstract class RedistributeData extends UnaryNode {
self: Product =>
- def output = child.output
+ override def output: Seq[Attribute] = child.output
}
case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan)
- extends RedistributeData {
-}
+ extends RedistributeData
case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan)
- extends RedistributeData {
-}
-
+ extends RedistributeData
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 3c3d7a3119064..288c11f69fe22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Expression, Row, SortOrder}
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{DataType, IntegerType}
/**
* Specifies how tuples that share common expressions will be distributed when a query is executed
@@ -72,7 +72,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
"a single partition.")
// TODO: This is not really valid...
- def clustering = ordering.map(_.child).toSet
+ def clustering: Set[Expression] = ordering.map(_.child).toSet
}
sealed trait Partitioning {
@@ -113,7 +113,7 @@ case object SinglePartition extends Partitioning {
override def satisfies(required: Distribution): Boolean = true
- override def compatibleWith(other: Partitioning) = other match {
+ override def compatibleWith(other: Partitioning): Boolean = other match {
case SinglePartition => true
case _ => false
}
@@ -124,7 +124,7 @@ case object BroadcastPartitioning extends Partitioning {
override def satisfies(required: Distribution): Boolean = true
- override def compatibleWith(other: Partitioning) = other match {
+ override def compatibleWith(other: Partitioning): Boolean = other match {
case SinglePartition => true
case _ => false
}
@@ -139,9 +139,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
extends Expression
with Partitioning {
- override def children = expressions
- override def nullable = false
- override def dataType = IntegerType
+ override def children: Seq[Expression] = expressions
+ override def nullable: Boolean = false
+ override def dataType: DataType = IntegerType
private[this] lazy val clusteringSet = expressions.toSet
@@ -152,7 +152,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}
- override def compatibleWith(other: Partitioning) = other match {
+ override def compatibleWith(other: Partitioning): Boolean = other match {
case BroadcastPartitioning => true
case h: HashPartitioning if h == this => true
case _ => false
@@ -178,9 +178,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
extends Expression
with Partitioning {
- override def children = ordering
- override def nullable = false
- override def dataType = IntegerType
+ override def children: Seq[SortOrder] = ordering
+ override def nullable: Boolean = false
+ override def dataType: DataType = IntegerType
private[this] lazy val clusteringSet = ordering.map(_.child).toSet
@@ -194,7 +194,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case _ => false
}
- override def compatibleWith(other: Partitioning) = other match {
+ override def compatibleWith(other: Partitioning): Boolean = other match {
case BroadcastPartitioning => true
case r: RangePartitioning if r == this => true
case _ => false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index f84ffe4e176cc..a2df51e598a2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.types.DataType
/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
private class MutableInt(var i: Int)
@@ -35,12 +36,12 @@ object CurrentOrigin {
override def initialValue: Origin = Origin()
}
- def get = value.get()
- def set(o: Origin) = value.set(o)
+ def get: Origin = value.get()
+ def set(o: Origin): Unit = value.set(o)
- def reset() = value.set(Origin())
+ def reset(): Unit = value.set(Origin())
- def setPosition(line: Int, start: Int) = {
+ def setPosition(line: Int, start: Int): Unit = {
value.set(
value.get.copy(line = Some(line), startPosition = Some(start)))
}
@@ -56,7 +57,7 @@ object CurrentOrigin {
abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
self: BaseType with Product =>
- val origin = CurrentOrigin.get
+ val origin: Origin = CurrentOrigin.get
/** Returns a Seq of the children of this node */
def children: Seq[BaseType]
@@ -220,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
Some(arg)
}
case m: Map[_,_] => m
+ case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if children contains arg =>
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
@@ -276,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
Some(arg)
}
case m: Map[_,_] => m
+ case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if children contains arg =>
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
@@ -307,10 +310,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
* @param newArgs the new product arguments.
*/
def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
+ val defaultCtor =
+ getClass.getConstructors
+ .find(_.getParameterTypes.size != 0)
+ .getOrElse(sys.error(s"No valid constructor for $nodeName"))
+
try {
CurrentOrigin.withOrigin(origin) {
// Skip no-arg constructors that are just there for kryo.
- val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head
if (otherCopyArgs.isEmpty) {
defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
} else {
@@ -320,18 +328,24 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
} catch {
case e: java.lang.IllegalArgumentException =>
throw new TreeNodeException(
- this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? "
- + s"Exception message: ${e.getMessage}.")
+ this,
+ s"""
+ |Failed to copy node.
+ |Is otherCopyArgs specified correctly for $nodeName?
+ |Exception message: ${e.getMessage}
+ |ctor: $defaultCtor
+ |args: ${newArgs.mkString(", ")}
+ """.stripMargin)
}
}
/** Returns the name of this type of TreeNode. Defaults to the class name. */
- def nodeName = getClass.getSimpleName
+ def nodeName: String = getClass.getSimpleName
/**
* The arguments that should be included in the arg string. Defaults to the `productIterator`.
*/
- protected def stringArgs = productIterator
+ protected def stringArgs: Iterator[Any] = productIterator
/** Returns a string representing the arguments to this node, minus any children */
def argString: String = productIterator.flatMap {
@@ -343,18 +357,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
}.mkString(", ")
/** String representation of this node without any children */
- def simpleString = s"$nodeName $argString".trim
+ def simpleString: String = s"$nodeName $argString".trim
override def toString: String = treeString
/** Returns a string representation of the nodes in this tree */
- def treeString = generateTreeString(0, new StringBuilder).toString
+ def treeString: String = generateTreeString(0, new StringBuilder).toString
/**
* Returns a string representation of the nodes in this tree, where each operator is numbered.
* The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees.
*/
- def numberedTreeString =
+ def numberedTreeString: String =
treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n")
/**
@@ -406,14 +420,14 @@ trait BinaryNode[BaseType <: TreeNode[BaseType]] {
def left: BaseType
def right: BaseType
- def children = Seq(left, right)
+ def children: Seq[BaseType] = Seq(left, right)
}
/**
* A [[TreeNode]] with no children.
*/
trait LeafNode[BaseType <: TreeNode[BaseType]] {
- def children = Nil
+ def children: Seq[BaseType] = Nil
}
/**
@@ -421,6 +435,5 @@ trait LeafNode[BaseType <: TreeNode[BaseType]] {
*/
trait UnaryNode[BaseType <: TreeNode[BaseType]] {
def child: BaseType
- def children = child :: Nil
+ def children: Seq[BaseType] = child :: Nil
}
-
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
index 79a8e06d4b4d4..ea6aa1850db4c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
@@ -41,11 +41,11 @@ package object trees extends Logging {
* A [[TreeNode]] companion for reference equality for Hash based Collection.
*/
class TreeNodeRef(val obj: TreeNode[_]) {
- override def equals(o: Any) = o match {
+ override def equals(o: Any): Boolean = o match {
case that: TreeNodeRef => that.obj.eq(obj)
case _ => false
}
- override def hashCode = if (obj == null) 0 else obj.hashCode
+ override def hashCode: Int = if (obj == null) 0 else obj.hashCode
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index feed50f9a2a2d..c86214a2aa944 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -23,7 +23,7 @@ import org.apache.spark.util.Utils
package object util {
- def fileToString(file: File, encoding: String = "UTF-8") = {
+ def fileToString(file: File, encoding: String = "UTF-8"): String = {
val inStream = new FileInputStream(file)
val outStream = new ByteArrayOutputStream
try {
@@ -45,7 +45,7 @@ package object util {
def resourceToString(
resource:String,
encoding: String = "UTF-8",
- classLoader: ClassLoader = Utils.getSparkClassLoader) = {
+ classLoader: ClassLoader = Utils.getSparkClassLoader): String = {
val inStream = classLoader.getResourceAsStream(resource)
val outStream = new ByteArrayOutputStream
try {
@@ -93,7 +93,7 @@ package object util {
new String(out.toByteArray)
}
- def stringOrNull(a: AnyRef) = if (a == null) null else a.toString
+ def stringOrNull(a: AnyRef): String = if (a == null) null else a.toString
def benchmark[A](f: => A): A = {
val startTime = System.nanoTime()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 359aec4a7b5ab..756cd36f05c8c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -32,9 +32,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
val caseInsensitiveCatalog = new SimpleCatalog(false)
val caseSensitiveAnalyzer =
- new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true)
+ new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true) {
+ override val extendedResolutionRules = EliminateSubQueries :: Nil
+ }
val caseInsensitiveAnalyzer =
- new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false)
+ new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) {
+ override val extendedResolutionRules = EliminateSubQueries :: Nil
+ }
val checkAnalysis = new CheckAnalysis
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index ec7d15f5bc4e7..3cd7adf8cab5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
+import org.apache.spark.Logging
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedGetField}
@@ -46,7 +47,7 @@ private[sql] object Column {
* @groupname Ungrouped Support functions for DataFrames.
*/
@Experimental
-class Column(protected[sql] val expr: Expression) {
+class Column(protected[sql] val expr: Expression) extends Logging {
def this(name: String) = this(name match {
case "*" => UnresolvedStar(None)
@@ -109,7 +110,15 @@ class Column(protected[sql] val expr: Expression) {
*
* @group expr_ops
*/
- def === (other: Any): Column = EqualTo(expr, lit(other).expr)
+ def === (other: Any): Column = {
+ val right = lit(other).expr
+ if (this.expr == right) {
+ logWarning(
+ s"Constructing trivially true equals predicate, '${this.expr} = $right'. " +
+ "Perhaps you need to use aliases.")
+ }
+ EqualTo(expr, right)
+ }
/**
* Equality test.
@@ -594,6 +603,19 @@ class Column(protected[sql] val expr: Expression) {
*/
def as(alias: Symbol): Column = Alias(expr, alias.name)()
+ /**
+ * Gives the column an alias with metadata.
+ * {{{
+ * val metadata: Metadata = ...
+ * df.select($"colA".as("colB", metadata))
+ * }}}
+ *
+ * @group expr_ops
+ */
+ def as(alias: String, metadata: Metadata): Column = {
+ Alias(expr, alias)(explicitMetadata = Some(metadata))
+ }
+
/**
* Casts the column to a different data type.
* {{{
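Two user-visible Column changes in this file, sketched below with the public DataFrame API (column names are hypothetical): === now warns when both sides are the same expression, and as can attach explicit metadata.

    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.MetadataBuilder

    // Same expression on both sides: the patched === logs the "trivially true" warning.
    val trivially = col("id") === col("id")
    // Different expressions: a normal EqualTo predicate, no warning.
    val normal = col("id") === col("other_id")

    // Alias with explicit metadata, surfaced later through the resulting schema.
    val meta = new MetadataBuilder().putString("comment", "primary key").build()
    val renamed = col("id").as("pk", meta)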
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index dc9912b52dcab..e59cf9b9e037b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1210,38 +1210,56 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Returns a Catalyst Schema for the given java bean class.
*/
protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
+ val (dataType, _) = inferDataType(beanClass)
+ dataType.asInstanceOf[StructType].fields.map { f =>
+ AttributeReference(f.name, f.dataType, f.nullable)()
+ }
+ }
+
+ /**
+ * Infers the corresponding SQL data type of a Java class.
+ * @param clazz Java class
+ * @return (SQL data type, nullable)
+ */
+ private def inferDataType(clazz: Class[_]): (DataType, Boolean) = {
// TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
- val beanInfo = Introspector.getBeanInfo(beanClass)
-
- // Note: The ordering of elements may differ from when the schema is inferred in Scala.
- // This is because beanInfo.getPropertyDescriptors gives no guarantees about
- // element ordering.
- val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
- fields.map { property =>
- val (dataType, nullable) = property.getPropertyType match {
- case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
- (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
- case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
- case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
- case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
- case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
- case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
- case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
- case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
- case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
-
- case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
- case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
- case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
- case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
- case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
- case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
- case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
- case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
- case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
- case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
- }
- AttributeReference(property.getName, dataType, nullable)()
+ clazz match {
+ case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+ (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
+
+ case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
+ case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
+ case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
+ case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
+ case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
+ case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
+ case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
+ case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
+
+ case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
+ case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
+ case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
+ case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
+ case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
+ case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
+ case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
+
+ case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
+ case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
+ case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
+
+ case c: Class[_] if c.isArray =>
+ val (dataType, nullable) = inferDataType(c.getComponentType)
+ (ArrayType(dataType, nullable), true)
+
+ case _ =>
+ val beanInfo = Introspector.getBeanInfo(clazz)
+ val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+ val fields = properties.map { property =>
+ val (dataType, nullable) = inferDataType(property.getPropertyType)
+ new StructField(property.getName, dataType, nullable)
+ }
+ (new StructType(fields), true)
}
}
}
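
The hunk above turns bean-property inference into a recursive function over Class[_]: arrays recurse into their component type, and any class that is not a recognized primitive/boxed/UDT case is treated as a nested bean. A minimal self-contained sketch of that recursion, using made-up type names (IntT, StructT, ...) instead of Spark's Catalyst types:

import java.beans.Introspector

object BeanInferenceSketch {
  sealed trait SimpleType
  case object IntT extends SimpleType
  case object StringT extends SimpleType
  case class ArrayT(element: SimpleType, containsNull: Boolean) extends SimpleType
  case class StructT(fields: Seq[(String, SimpleType, Boolean)]) extends SimpleType

  // Returns (inferred type, nullable): primitives are non-nullable, boxed types are nullable.
  def infer(clazz: Class[_]): (SimpleType, Boolean) = clazz match {
    case c if c == java.lang.Integer.TYPE     => (IntT, false)
    case c if c == classOf[java.lang.Integer] => (IntT, true)
    case c if c == classOf[java.lang.String]  => (StringT, true)
    case c if c.isArray =>
      val (elem, nullable) = infer(c.getComponentType)   // recurse into the element type
      (ArrayT(elem, nullable), true)
    case other =>
      // Anything else is treated as a bean: infer each readable property recursively.
      val props = Introspector.getBeanInfo(other).getPropertyDescriptors
        .filterNot(_.getName == "class")
      val fields = props.toSeq.map { p =>
        val (t, n) = infer(p.getPropertyType)
        (p.getName, t, n)
      }
      (StructT(fields), true)
  }
}

The real code path is exercised by the JavaDataFrameSuite test added later in this patch via createDataFrame(rdd, Bean.class).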
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 20c9bc3e75542..1f5251a20376f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.util.MutablePair
+import org.apache.spark.util.{CompletionIterator, MutablePair}
import org.apache.spark.util.collection.ExternalSorter
/**
@@ -194,7 +194,9 @@ case class ExternalSort(
val ordering = newOrdering(sortOrder, child.output)
val sorter = new ExternalSorter[Row, Null, Row](ordering = Some(ordering))
sorter.insertAll(iterator.map(r => (r, null)))
- sorter.iterator.map(_._1)
+ val baseIterator = sorter.iterator.map(_._1)
+ // TODO(marmbrus): The complex type signature below thwarts inference for no reason.
+ CompletionIterator[Row, Iterator[Row]](baseIterator, sorter.stop())
}, preservesPartitioning = true)
}
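
The ExternalSort change above wraps the sorter's output so that sorter.stop() runs once the rows have been fully consumed, releasing the sorter's buffers and spill files. CompletionIterator is an internal Spark utility; the pattern itself is just an iterator that fires a callback exactly once on exhaustion, roughly:

object CompletionIteratorSketch {
  // Illustrative stand-in for the pattern, not Spark's CompletionIterator.
  class CleanupIterator[A](underlying: Iterator[A], onComplete: => Unit) extends Iterator[A] {
    private var fired = false
    override def hasNext: Boolean = {
      val more = underlying.hasNext
      if (!more && !fired) { fired = true; onComplete }  // run the cleanup exactly once
      more
    }
    override def next(): A = underlying.next()
  }

  def main(args: Array[String]): Unit = {
    val rows = new CleanupIterator(Iterator(1, 2, 3), println("sorter stopped"))
    rows.foreach(println)   // prints 1, 2, 3, then "sorter stopped"
  }
}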
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index d2e807d3a69b6..eb46b46ca5bf4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -21,7 +21,7 @@ import scala.language.existentials
import scala.language.implicitConversions
import org.apache.spark.Logging
-import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
+import org.apache.spark.sql.{AnalysisException, SaveMode, DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -204,19 +204,25 @@ private[sql] object ResolvedDataSource {
provider: String,
options: Map[String, String]): ResolvedDataSource = {
val clazz: Class[_] = lookupDataSource(provider)
+ def className = clazz.getCanonicalName
val relation = userSpecifiedSchema match {
case Some(schema: StructType) => clazz.newInstance() match {
case dataSource: SchemaRelationProvider =>
dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
case dataSource: org.apache.spark.sql.sources.RelationProvider =>
- sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.")
+ throw new AnalysisException(s"$className does not allow user-specified schemas.")
+ case _ =>
+ throw new AnalysisException(s"$className is not a RelationProvider.")
}
case None => clazz.newInstance() match {
case dataSource: RelationProvider =>
dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
- sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.")
+ throw new AnalysisException(
+ s"A schema needs to be specified when using $className.")
+ case _ =>
+ throw new AnalysisException(s"$className is not a RelationProvider.")
}
}
new ResolvedDataSource(clazz, relation)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 2d586f784ac5a..1ff2d5a190521 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -17,29 +17,39 @@
package test.org.apache.spark.sql;
+import java.io.Serializable;
+import java.util.Arrays;
+
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
+import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.test.TestSQLContext$;
-import static org.apache.spark.sql.functions.*;
+import org.apache.spark.sql.types.*;
+import static org.apache.spark.sql.functions.*;
public class JavaDataFrameSuite {
+ private transient JavaSparkContext jsc;
private transient SQLContext context;
@Before
public void setUp() {
// Trigger static initializer of TestData
TestData$.MODULE$.testData();
+ jsc = new JavaSparkContext(TestSQLContext.sparkContext());
context = TestSQLContext$.MODULE$;
}
@After
public void tearDown() {
+ jsc = null;
context = null;
}
@@ -90,4 +100,33 @@ public void testShow() {
df.show();
df.show(1000);
}
+
+ public static class Bean implements Serializable {
+ private double a = 0.0;
+ private Integer[] b = new Integer[]{0, 1};
+
+ public double getA() {
+ return a;
+ }
+
+ public Integer[] getB() {
+ return b;
+ }
+ }
+
+ @Test
+ public void testCreateDataFrameFromJavaBeans() {
+ Bean bean = new Bean();
+ JavaRDD rdd = jsc.parallelize(Arrays.asList(bean));
+ DataFrame df = context.createDataFrame(rdd, Bean.class);
+ StructType schema = df.schema();
+ Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()),
+ schema.apply("a"));
+ Assert.assertEquals(
+ new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()),
+ schema.apply("b"));
+ Row first = df.select("a", "b").first();
+ Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
+ Assert.assertArrayEquals(bean.getB(), first.getAs(1));
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index a53ae97d6243a..bc8fae100db6a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -17,12 +17,10 @@
package org.apache.spark.sql
-import org.apache.spark.sql.catalyst.expressions.NamedExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Project, NoRelation}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
-import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType}
+import org.apache.spark.sql.types._
class ColumnExpressionSuite extends QueryTest {
@@ -322,4 +320,15 @@ class ColumnExpressionSuite extends QueryTest {
assert('key.desc == 'key.desc)
assert('key.desc != 'key.asc)
}
+
+ test("alias with metadata") {
+ val metadata = new MetadataBuilder()
+ .putString("originName", "value")
+ .build()
+ val schema = testData
+ .select($"*", col("value").as("abc", metadata))
+ .schema
+ assert(schema("value").metadata === Metadata.empty)
+ assert(schema("abc").metadata === metadata)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index ff441ef26f9c0..c30ed694a62f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -108,6 +108,13 @@ class DataFrameSuite extends QueryTest {
)
}
+ test("self join with aliases") {
+ val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str")
+ checkAnswer(
+ df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(),
+ Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
+ }
+
test("explode") {
val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters")
val df2 =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index dd0948ad824be..e4dee87849fd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -34,7 +34,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("equi-join is hash-join") {
val x = testData2.as("x")
val y = testData2.as("y")
- val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.analyzed
+ val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan
val planned = planner.HashJoin(join)
assert(planned.size === 1)
}
@@ -109,7 +109,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("multiple-key equi-join is hash-join") {
val x = testData2.as("x")
val y = testData2.as("y")
- val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.analyzed
+ val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan
val planned = planner.HashJoin(join)
assert(planned.size === 1)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index be105c6e83594..d615542ab50a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -50,4 +50,10 @@ class UDFSuite extends QueryTest {
.select($"ret.f1").head().getString(0)
assert(result === "test")
}
+
+ test("udf that is transformed") {
+ udf.register("makeStruct", (x: Int, y: Int) => (x, y))
+ // 1 + 1 is constant folded causing a transformation.
+ assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 4c5eb48661f7d..d1a99555e90c6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -459,7 +459,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") =>
val parquetRelation = convertToParquetRelation(relation)
val attributedRewrites = relation.output.zip(parquetRelation.output)
- (relation, parquetRelation, attributedRewrites)
+ (relation -> relation.output, parquetRelation, attributedRewrites)
// Write path
case InsertIntoHiveTable(relation: MetastoreRelation, _, _, _)
@@ -470,7 +470,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") =>
val parquetRelation = convertToParquetRelation(relation)
val attributedRewrites = relation.output.zip(parquetRelation.output)
- (relation, parquetRelation, attributedRewrites)
+ (relation -> relation.output, parquetRelation, attributedRewrites)
// Read path
case p @ PhysicalOperation(_, _, relation: MetastoreRelation)
@@ -479,33 +479,35 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") =>
val parquetRelation = convertToParquetRelation(relation)
val attributedRewrites = relation.output.zip(parquetRelation.output)
- (relation, parquetRelation, attributedRewrites)
+ (relation -> relation.output, parquetRelation, attributedRewrites)
}
+ // Quick fix for SPARK-6450: we use both the MetastoreRelation instances and their output
+ // attributes as the map key. MetastoreRelation.equals doesn't take output attributes into
+ // account, so multiple MetastoreRelation instances pointing to the same table would otherwise
+ // collapse into a single map entry. A proper fix would be to override equals and hashCode in
+ // MetastoreRelation.
val relationMap = toBeReplaced.map(r => (r._1, r._2)).toMap
val attributedRewrites = AttributeMap(toBeReplaced.map(_._3).fold(Nil)(_ ++: _))
// Replaces all `MetastoreRelation`s with corresponding `ParquetRelation2`s, and fixes
// attribute IDs referenced in other nodes.
plan.transformUp {
- case r: MetastoreRelation if relationMap.contains(r) => {
- val parquetRelation = relationMap(r)
- val withAlias =
- r.alias.map(a => Subquery(a, parquetRelation)).getOrElse(
- Subquery(r.tableName, parquetRelation))
+ case r: MetastoreRelation if relationMap.contains(r -> r.output) =>
+ val parquetRelation = relationMap(r -> r.output)
+ val alias = r.alias.getOrElse(r.tableName)
+ Subquery(alias, parquetRelation)
- withAlias
- }
case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite)
- if relationMap.contains(r) => {
- val parquetRelation = relationMap(r)
+ if relationMap.contains(r -> r.output) =>
+ val parquetRelation = relationMap(r -> r.output)
InsertIntoTable(parquetRelation, partition, child, overwrite)
- }
+
case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite)
- if relationMap.contains(r) => {
- val parquetRelation = relationMap(r)
+ if relationMap.contains(r -> r.output) =>
+ val parquetRelation = relationMap(r -> r.output)
InsertIntoTable(parquetRelation, partition, child, overwrite)
- }
+
case other => other.transformExpressions {
case a: Attribute if a.resolved => attributedRewrites.getOrElse(a, a)
}
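
The keying change above (relation -> relation.output) works around MetastoreRelation.equals ignoring output attribute IDs: two logically distinct relation instances over the same table would otherwise collapse into one map entry. A tiny illustration of the collapse, with a made-up Relation case class standing in for MetastoreRelation:

object MapKeyCollapseSketch extends App {
  case class Relation(database: String, table: String)   // equality ignores per-instance state

  val r1 = Relation("default", "src")
  val r2 = Relation("default", "src")                    // equal to r1, but a separate instance
  val out1 = Seq("key#1", "value#2")                     // stand-ins for output attribute IDs
  val out2 = Seq("key#3", "value#4")

  val byRelation = Map(r1 -> "parquetRelationA", r2 -> "parquetRelationB")
  assert(byRelation.size == 1)                           // r2 silently replaced r1

  val byRelationAndOutput = Map((r1, out1) -> "parquetRelationA", (r2, out2) -> "parquetRelationB")
  assert(byRelationAndOutput.size == 2)                  // both relations keep their own entry
}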
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 51775eb4cd6a0..c45c4ad70fae9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -55,37 +55,8 @@ private[hive] case object NativePlaceholder extends Command
/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
private[hive] object HiveQl {
protected val nativeCommands = Seq(
- "TOK_DESCFUNCTION",
- "TOK_DESCDATABASE",
- "TOK_SHOW_CREATETABLE",
- "TOK_SHOWCOLUMNS",
- "TOK_SHOW_TABLESTATUS",
- "TOK_SHOWDATABASES",
- "TOK_SHOWFUNCTIONS",
- "TOK_SHOWINDEXES",
- "TOK_SHOWINDEXES",
- "TOK_SHOWPARTITIONS",
- "TOK_SHOW_TBLPROPERTIES",
-
- "TOK_LOCKTABLE",
- "TOK_SHOWLOCKS",
- "TOK_UNLOCKTABLE",
-
- "TOK_SHOW_ROLES",
- "TOK_CREATEROLE",
- "TOK_DROPROLE",
- "TOK_GRANT",
- "TOK_GRANT_ROLE",
- "TOK_REVOKE",
- "TOK_SHOW_GRANT",
- "TOK_SHOW_ROLE_GRANT",
- "TOK_SHOW_SET_ROLE",
-
- "TOK_CREATEFUNCTION",
- "TOK_DROPFUNCTION",
-
- "TOK_ALTERDATABASE_PROPERTIES",
"TOK_ALTERDATABASE_OWNER",
+ "TOK_ALTERDATABASE_PROPERTIES",
"TOK_ALTERINDEX_PROPERTIES",
"TOK_ALTERINDEX_REBUILD",
"TOK_ALTERTABLE_ADDCOLS",
@@ -102,28 +73,61 @@ private[hive] object HiveQl {
"TOK_ALTERTABLE_SKEWED",
"TOK_ALTERTABLE_TOUCH",
"TOK_ALTERTABLE_UNARCHIVE",
- "TOK_CREATEDATABASE",
- "TOK_CREATEFUNCTION",
- "TOK_CREATEINDEX",
- "TOK_DROPDATABASE",
- "TOK_DROPINDEX",
- "TOK_DROPTABLE_PROPERTIES",
- "TOK_MSCK",
-
"TOK_ALTERVIEW_ADDPARTS",
"TOK_ALTERVIEW_AS",
"TOK_ALTERVIEW_DROPPARTS",
"TOK_ALTERVIEW_PROPERTIES",
"TOK_ALTERVIEW_RENAME",
+
+ "TOK_CREATEDATABASE",
+ "TOK_CREATEFUNCTION",
+ "TOK_CREATEINDEX",
+ "TOK_CREATEROLE",
"TOK_CREATEVIEW",
- "TOK_DROPVIEW_PROPERTIES",
+
+ "TOK_DESCDATABASE",
+ "TOK_DESCFUNCTION",
+
+ "TOK_DROPDATABASE",
+ "TOK_DROPFUNCTION",
+ "TOK_DROPINDEX",
+ "TOK_DROPROLE",
+ "TOK_DROPTABLE_PROPERTIES",
"TOK_DROPVIEW",
-
+ "TOK_DROPVIEW_PROPERTIES",
+
"TOK_EXPORT",
+
+ "TOK_GRANT",
+ "TOK_GRANT_ROLE",
+
"TOK_IMPORT",
+
"TOK_LOAD",
-
- "TOK_SWITCHDATABASE"
+
+ "TOK_LOCKTABLE",
+
+ "TOK_MSCK",
+
+ "TOK_REVOKE",
+
+ "TOK_SHOW_CREATETABLE",
+ "TOK_SHOW_GRANT",
+ "TOK_SHOW_ROLE_GRANT",
+ "TOK_SHOW_ROLES",
+ "TOK_SHOW_SET_ROLE",
+ "TOK_SHOW_TABLESTATUS",
+ "TOK_SHOW_TBLPROPERTIES",
+ "TOK_SHOWCOLUMNS",
+ "TOK_SHOWDATABASES",
+ "TOK_SHOWFUNCTIONS",
+ "TOK_SHOWINDEXES",
+ "TOK_SHOWLOCKS",
+ "TOK_SHOWPARTITIONS",
+
+ "TOK_SWITCHDATABASE",
+
+ "TOK_UNLOCKTABLE"
)
// Commands that we do not need to explain.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index af309c0c6ce2c..3563472c7ae81 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.Utilities
import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable}
import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc}
import org.apache.hadoop.hive.serde2.Deserializer
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.primitive._
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf}
@@ -116,7 +116,7 @@ class HadoopTableReader(
val hconf = broadcastedHiveConf.value.value
val deserializer = deserializerClass.newInstance()
deserializer.initialize(hconf, tableDesc.getProperties)
- HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow)
+ HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer)
}
deserializedHadoopRDD
@@ -189,9 +189,13 @@ class HadoopTableReader(
val hconf = broadcastedHiveConf.value.value
val deserializer = localDeserializer.newInstance()
deserializer.initialize(hconf, partProps)
+ // get the table deserializer
+ val tableSerDe = tableDesc.getDeserializerClass.newInstance()
+ tableSerDe.initialize(hconf, tableDesc.getProperties)
// fill the non partition key attributes
- HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, mutableRow)
+ HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs,
+ mutableRow, tableSerDe)
}
}.toSeq
@@ -261,25 +265,36 @@ private[hive] object HadoopTableReader extends HiveInspectors {
* Transform all given raw `Writable`s into `Row`s.
*
* @param iterator Iterator of all `Writable`s to be transformed
- * @param deserializer The `Deserializer` associated with the input `Writable`
+ * @param rawDeser The `Deserializer` associated with the input `Writable`
* @param nonPartitionKeyAttrs Attributes that should be filled together with their corresponding
* positions in the output schema
* @param mutableRow A reusable `MutableRow` that should be filled
+ * @param tableDeser The table-level Deserializer
* @return An `Iterator[Row]` transformed from `iterator`
*/
def fillObject(
iterator: Iterator[Writable],
- deserializer: Deserializer,
+ rawDeser: Deserializer,
nonPartitionKeyAttrs: Seq[(Attribute, Int)],
- mutableRow: MutableRow): Iterator[Row] = {
+ mutableRow: MutableRow,
+ tableDeser: Deserializer): Iterator[Row] = {
+
+ val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) {
+ rawDeser.getObjectInspector.asInstanceOf[StructObjectInspector]
+ } else {
+ HiveShim.getConvertedOI(
+ rawDeser.getObjectInspector,
+ tableDeser.getObjectInspector).asInstanceOf[StructObjectInspector]
+ }
- val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector]
val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) =>
soi.getStructFieldRef(attr.name) -> ordinal
}.unzip
- // Builds specific unwrappers ahead of time according to object inspector types to avoid pattern
- // matching and branching costs per row.
+ /**
+ * Builds specific unwrappers ahead of time according to object inspector
+ * types to avoid pattern matching and branching costs per row.
+ */
val unwrappers: Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map {
_.getFieldObjectInspector match {
case oi: BooleanObjectInspector =>
@@ -316,9 +331,11 @@ private[hive] object HadoopTableReader extends HiveInspectors {
}
}
+ val converter = ObjectInspectorConverters.getConverter(rawDeser.getObjectInspector, soi)
+
// Map each tuple to a row object
iterator.map { value =>
- val raw = deserializer.deserialize(value)
+ val raw = converter.convert(rawDeser.deserialize(value))
var i = 0
while (i < fieldRefs.length) {
val fieldValue = soi.getStructFieldData(raw, fieldRefs(i))
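
The fillObject change above builds the field extractors against the table-level object inspector and, when a partition was written with a different SerDe, converts each deserialized row into that layout first. Stripped of the Hive APIs, the per-row path is convert-then-extract; a sketch with hypothetical column names, not Hive types:

object ConvertThenExtractSketch extends App {
  type RawRow = Map[String, Any]

  // Hypothetical partition layout that stores the columns under different names.
  val partitionRows: Iterator[RawRow] = Iterator(Map("k" -> 238, "v" -> "val_238"),
                                                 Map("k" -> 86,  "v" -> "val_86"))

  // Converter from the partition layout to the table layout, built once per partition.
  val toTableLayout: RawRow => RawRow = r => Map("key" -> r("k"), "value" -> r("v"))

  // Field extractors built once against the *table* layout, then applied per row.
  val extractors: Seq[RawRow => Any] = Seq(_.apply("key"), _.apply("value"))

  val rows = partitionRows.map(toTableLayout).map(r => extractors.map(_(r)))
  rows.foreach(println)   // List(238, val_238) then List(86, val_86)
}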
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index bfe43373d9534..47305571e579e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -375,9 +375,8 @@ private[hive] case class HiveUdafFunction(
private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
- // Cast required to avoid type inference selecting a deprecated Hive API.
private val buffer =
- function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer]
+ function.getNewAggregationBuffer
override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index dc61e9d2e3522..a3497eadd67f6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -23,6 +23,7 @@ import java.util.{Set => JavaSet}
import org.apache.hadoop.hive.ql.exec.FunctionRegistry
import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat}
import org.apache.hadoop.hive.ql.metadata.Table
+import org.apache.hadoop.hive.ql.parse.VariableSubstitution
import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.serde2.RegexSerDe
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
@@ -153,8 +154,13 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
val describedTable = "DESCRIBE (\\w+)".r
+ val vs = new VariableSubstitution()
+
+ // We substitute variables in hql before passing the text to parseSql(), because the Hive
+ // parser needs the substituted text. HiveContext.sql() does this too, but it returns a
+ // DataFrame, whereas here we need a logical plan, so we cannot reuse it.
protected[hive] class HiveQLQueryExecution(hql: String)
- extends this.QueryExecution(HiveQl.parseSql(hql)) {
+ extends this.QueryExecution(HiveQl.parseSql(vs.substitute(hiveconf, hql))) {
def hiveExec(): Seq[String] = runSqlHive(hql)
override def toString: String = hql + "\n" + super.toString
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index 44d24273e722a..221a0c263d36c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -92,12 +92,12 @@ class CachedTableSuite extends QueryTest {
}
test("Drop cached table") {
- sql("CREATE TABLE test(a INT)")
- cacheTable("test")
- sql("SELECT * FROM test").collect()
- sql("DROP TABLE test")
+ sql("CREATE TABLE cachedTableTest(a INT)")
+ cacheTable("cachedTableTest")
+ sql("SELECT * FROM cachedTableTest").collect()
+ sql("DROP TABLE cachedTableTest")
intercept[AnalysisException] {
- sql("SELECT * FROM test").collect()
+ sql("SELECT * FROM cachedTableTest").collect()
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
index f04437c595bf6..968557c9c4686 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
@@ -19,12 +19,29 @@ package org.apache.spark.sql.hive
import java.io.{OutputStream, PrintStream}
+import scala.util.Try
+
+import org.scalatest.BeforeAndAfter
+
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.{AnalysisException, QueryTest}
-import scala.util.Try
-class ErrorPositionSuite extends QueryTest {
+class ErrorPositionSuite extends QueryTest with BeforeAndAfter {
+
+ before {
+ Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes")
+ }
+
+ positionTest("ambiguous attribute reference 1",
+ "SELECT a from dupAttributes", "a")
+
+ positionTest("ambiguous attribute reference 2",
+ "SELECT a, b from dupAttributes", "a")
+
+ positionTest("ambiguous attribute reference 3",
+ "SELECT b, a from dupAttributes", "a")
positionTest("unresolved attribute 1",
"SELECT x FROM src", "x")
@@ -127,6 +144,10 @@ class ErrorPositionSuite extends QueryTest {
val error = intercept[AnalysisException] {
quietly(sql(query))
}
+
+ assert(!error.getMessage.contains("Seq("))
+ assert(!error.getMessage.contains("List("))
+
val (line, expectedLineNum) = query.split("\n").zipWithIndex.collect {
case (l, i) if l.contains(token) => (l, i + 1)
}.headOption.getOrElse(sys.error(s"Invalid test. Token $token not in $query"))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 381cd2a29123e..aa6fb42de7f88 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -32,9 +32,12 @@ import org.apache.spark.sql.hive.test.TestHive._
case class TestData(key: Int, value: String)
+case class ThreeCloumntable(key: Int, value: String, key1: String)
+
class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
import org.apache.spark.sql.hive.test.TestHive.implicits._
+
val testData = TestHive.sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString))).toDF()
@@ -186,4 +189,43 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
sql("DROP TABLE hiveTableWithStructValue")
}
+
+ test("SPARK-5498:partition schema does not match table schema") {
+ val testData = TestHive.sparkContext.parallelize(
+ (1 to 10).map(i => TestData(i, i.toString))).toDF()
+ testData.registerTempTable("testData")
+
+ val testDatawithNull = TestHive.sparkContext.parallelize(
+ (1 to 10).map(i => ThreeCloumntable(i, i.toString, null))).toDF()
+
+ val tmpDir = Files.createTempDir()
+ sql(s"CREATE TABLE table_with_partition(key int,value string) PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ")
+ sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') SELECT key,value FROM testData")
+
+ // test that the schema is the same between the partition and the table
+ sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT")
+ checkAnswer(sql("select key,value from table_with_partition where ds='1' "),
+ testData.collect.toSeq
+ )
+
+ // test a field whose type differs between the partition and the table
+ sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT")
+ checkAnswer(sql("select key,value from table_with_partition where ds='1' "),
+ testData.collect.toSeq
+ )
+
+ // add a column to the table
+ sql("ALTER TABLE table_with_partition ADD COLUMNS(key1 string)")
+ checkAnswer(sql("select key,value,key1 from table_with_partition where ds='1' "),
+ testDatawithNull.collect.toSeq
+ )
+
+ // change a column name in the table
+ sql("ALTER TABLE table_with_partition CHANGE COLUMN key keynew BIGINT")
+ checkAnswer(sql("select keynew,value from table_with_partition where ds='1' "),
+ testData.collect.toSeq
+ )
+
+ sql("DROP TABLE table_with_partition")
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index ff2e6ea9ea51d..e5ad0bf552073 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -579,7 +579,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
Row(3) :: Row(4) :: Nil
)
- table("test_parquet_ctas").queryExecution.analyzed match {
+ table("test_parquet_ctas").queryExecution.optimizedPlan match {
case LogicalRelation(p: ParquetRelation2) => // OK
case _ =>
fail(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index cb405f56bf53d..d7c5d1a25a82b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -22,7 +22,7 @@ import java.util
import java.util.Properties
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
+import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF}
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
@@ -93,6 +93,15 @@ class HiveUdfSuite extends QueryTest {
sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf")
}
+ test("SPARK-6409 UDAFAverage test") {
+ sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'")
+ checkAnswer(
+ sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"),
+ Seq(Row(1.0, 260.182)))
+ sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg")
+ TestHive.reset()
+ }
+
test("SPARK-2693 udaf aggregates test") {
checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src").collect().toSeq)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index d891c4e8903d9..432d65a874518 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -292,7 +292,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase {
Seq(Row(1, "str1"))
)
- table("test_parquet_ctas").queryExecution.analyzed match {
+ table("test_parquet_ctas").queryExecution.optimizedPlan match {
case LogicalRelation(p: ParquetRelation2) => // OK
case _ =>
fail(
@@ -365,6 +365,31 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase {
sql("DROP TABLE IF EXISTS test_insert_parquet")
}
+
+ test("SPARK-6450 regression test") {
+ sql(
+ """CREATE TABLE IF NOT EXISTS ms_convert (key INT)
+ |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
+ |STORED AS
+ | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
+ | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
+ """.stripMargin)
+
+ // This shouldn't throw AnalysisException
+ val analyzed = sql(
+ """SELECT key FROM ms_convert
+ |UNION ALL
+ |SELECT key FROM ms_convert
+ """.stripMargin).queryExecution.analyzed
+
+ assertResult(2) {
+ analyzed.collect {
+ case r @ LogicalRelation(_: ParquetRelation2) => r
+ }.size
+ }
+
+ sql("DROP TABLE ms_convert")
+ }
}
class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase {
diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
index 30646ddbc29d8..0ed93c2c5b1fa 100644
--- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
+++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
@@ -34,7 +34,7 @@ import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc}
import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.ql.stats.StatsSetupConst
import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo}
-import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, PrimitiveObjectInspector}
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, ObjectInspector, PrimitiveObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory
import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory}
import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory}
@@ -210,7 +210,7 @@ private[hive] object HiveShim {
def getDataLocationPath(p: Partition) = p.getPartitionPath
- def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsForPruner(tbl)
+ def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsForPruner(tbl)
def compatibilityBlackList = Seq(
"decimal_.*",
@@ -244,6 +244,12 @@ private[hive] object HiveShim {
}
}
+ def getConvertedOI(
+ inputOI: ObjectInspector,
+ outputOI: ObjectInspector): ObjectInspector = {
+ ObjectInspectorConverters.getConvertedOI(inputOI, outputOI, true)
+ }
+
def prepareWritable(w: Writable): Writable = {
w
}
diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
index f9fcbdae15745..7577309900209 100644
--- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
+++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.hive
+import java.util
import java.util.{ArrayList => JArrayList}
import java.util.Properties
import java.rmi.server.UID
@@ -38,7 +39,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory}
-import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector}
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, PrimitiveObjectInspector, ObjectInspector}
import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils}
import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable
@@ -400,7 +401,11 @@ private[hive] object HiveShim {
Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale())
}
}
-
+
+ def getConvertedOI(inputOI: ObjectInspector, outputOI: ObjectInspector): ObjectInspector = {
+ ObjectInspectorConverters.getConvertedOI(inputOI, outputOI)
+ }
+
/*
* Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that
* is needed to initialize before serialization.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index db64e11e16304..f73b463d07779 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -67,12 +67,12 @@ object Checkpoint extends Logging {
val REGEX = (PREFIX + """([\d]+)([\w\.]*)""").r
/** Get the checkpoint file for the given checkpoint time */
- def checkpointFile(checkpointDir: String, checkpointTime: Time) = {
+ def checkpointFile(checkpointDir: String, checkpointTime: Time): Path = {
new Path(checkpointDir, PREFIX + checkpointTime.milliseconds)
}
/** Get the checkpoint backup file for the given checkpoint time */
- def checkpointBackupFile(checkpointDir: String, checkpointTime: Time) = {
+ def checkpointBackupFile(checkpointDir: String, checkpointTime: Time): Path = {
new Path(checkpointDir, PREFIX + checkpointTime.milliseconds + ".bk")
}
@@ -232,6 +232,8 @@ object CheckpointReader extends Logging {
def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] =
{
val checkpointPath = new Path(checkpointDir)
+
+ // TODO(rxin): Why is this a def?!
def fs = checkpointPath.getFileSystem(hadoopConf)
// Try to find the checkpoint files
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index 0e285d6088ec1..175140481e5ae 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -100,11 +100,11 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
}
}
- def getInputStreams() = this.synchronized { inputStreams.toArray }
+ def getInputStreams(): Array[InputDStream[_]] = this.synchronized { inputStreams.toArray }
- def getOutputStreams() = this.synchronized { outputStreams.toArray }
+ def getOutputStreams(): Array[DStream[_]] = this.synchronized { outputStreams.toArray }
- def getReceiverInputStreams() = this.synchronized {
+ def getReceiverInputStreams(): Array[ReceiverInputDStream[_]] = this.synchronized {
inputStreams.filter(_.isInstanceOf[ReceiverInputDStream[_]])
.map(_.asInstanceOf[ReceiverInputDStream[_]])
.toArray
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala
index a0d8fb5ab93ec..3249bb348981f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala
@@ -55,7 +55,6 @@ case class Duration (private val millis: Long) {
def div(that: Duration): Double = this / that
-
def isMultipleOf(that: Duration): Boolean =
(this.millis % that.millis == 0)
@@ -71,7 +70,7 @@ case class Duration (private val millis: Long) {
def milliseconds: Long = millis
- def prettyPrint = Utils.msDurationToString(millis)
+ def prettyPrint: String = Utils.msDurationToString(millis)
}
@@ -80,7 +79,7 @@ case class Duration (private val millis: Long) {
* a given number of milliseconds.
*/
object Milliseconds {
- def apply(milliseconds: Long) = new Duration(milliseconds)
+ def apply(milliseconds: Long): Duration = new Duration(milliseconds)
}
/**
@@ -88,7 +87,7 @@ object Milliseconds {
* a given number of seconds.
*/
object Seconds {
- def apply(seconds: Long) = new Duration(seconds * 1000)
+ def apply(seconds: Long): Duration = new Duration(seconds * 1000)
}
/**
@@ -96,7 +95,7 @@ object Seconds {
* a given number of minutes.
*/
object Minutes {
- def apply(minutes: Long) = new Duration(minutes * 60000)
+ def apply(minutes: Long): Duration = new Duration(minutes * 60000)
}
// Java-friendlier versions of the objects above.
@@ -107,16 +106,16 @@ object Durations {
/**
* @return [[org.apache.spark.streaming.Duration]] representing given number of milliseconds.
*/
- def milliseconds(milliseconds: Long) = Milliseconds(milliseconds)
+ def milliseconds(milliseconds: Long): Duration = Milliseconds(milliseconds)
/**
* @return [[org.apache.spark.streaming.Duration]] representing given number of seconds.
*/
- def seconds(seconds: Long) = Seconds(seconds)
+ def seconds(seconds: Long): Duration = Seconds(seconds)
/**
* @return [[org.apache.spark.streaming.Duration]] representing given number of minutes.
*/
- def minutes(minutes: Long) = Minutes(minutes)
+ def minutes(minutes: Long): Duration = Minutes(minutes)
}
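
Most of the streaming-side changes in this patch, like the Duration/Durations ones above, simply add explicit result types to public members. Without an annotation the signature is whatever the body infers, so a later edit to the body can silently change a public API; spelling the type out pins the contract down. A two-line illustration:

object ExplicitReturnTypeSketch {
  def inferred = Seq(1, 2, 3)              // signature is whatever the compiler infers today
  def explicit: Seq[Int] = Seq(1, 2, 3)    // the public contract is stated and checked
}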
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala
index ad4f3fdd14ad6..3f5be785e1b1a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala
@@ -39,18 +39,18 @@ class Interval(val beginTime: Time, val endTime: Time) {
this.endTime < that.endTime
}
- def <= (that: Interval) = (this < that || this == that)
+ def <= (that: Interval): Boolean = (this < that || this == that)
- def > (that: Interval) = !(this <= that)
+ def > (that: Interval): Boolean = !(this <= that)
- def >= (that: Interval) = !(this < that)
+ def >= (that: Interval): Boolean = !(this < that)
- override def toString = "[" + beginTime + ", " + endTime + "]"
+ override def toString: String = "[" + beginTime + ", " + endTime + "]"
}
private[streaming]
object Interval {
- def currentInterval(duration: Duration): Interval = {
+ def currentInterval(duration: Duration): Interval = {
val time = new Time(System.currentTimeMillis)
val intervalBegin = time.floor(duration)
new Interval(intervalBegin, intervalBegin + duration)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 543224d4b07bc..f57f295874645 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -188,7 +188,7 @@ class StreamingContext private[streaming] (
/**
* Return the associated Spark context
*/
- def sparkContext = sc
+ def sparkContext: SparkContext = sc
/**
* Set each DStreams in this context to remember RDDs it generated in the last given duration.
@@ -596,7 +596,8 @@ object StreamingContext extends Logging {
@deprecated("Replaced by implicit functions in the DStream companion object. This is " +
"kept here only for backward compatibility.", "1.3.0")
def toPairDStreamFunctions[K, V](stream: DStream[(K, V)])
- (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = {
+ (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null)
+ : PairDStreamFunctions[K, V] = {
DStream.toPairDStreamFunctions(stream)(kt, vt, ord)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index 2eabdd9387913..73030e15c5661 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -415,8 +415,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
implicit val cmv2: ClassTag[V2] = fakeClassTag
implicit val cmw: ClassTag[W] = fakeClassTag
- def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] =
+ def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] = {
transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd
+ }
dstream.transformWith[(K2, V2), W](other.dstream, scalaTransform(_, _, _))
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index 7053f47ec69a2..4c28654ef6413 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -176,11 +176,11 @@ private[python] abstract class PythonDStream(
val func = new TransformFunction(pfunc)
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
- val asJavaDStream = JavaDStream.fromDStream(this)
+ val asJavaDStream: JavaDStream[Array[Byte]] = JavaDStream.fromDStream(this)
}
/**
@@ -212,7 +212,7 @@ private[python] class PythonTransformed2DStream(
val func = new TransformFunction(pfunc)
- override def dependencies = List(parent, parent2)
+ override def dependencies: List[DStream[_]] = List(parent, parent2)
override def slideDuration: Duration = parent.slideDuration
@@ -223,7 +223,7 @@ private[python] class PythonTransformed2DStream(
func(Some(rdd1), Some(rdd2), validTime)
}
- val asJavaDStream = JavaDStream.fromDStream(this)
+ val asJavaDStream: JavaDStream[Array[Byte]] = JavaDStream.fromDStream(this)
}
/**
@@ -260,12 +260,15 @@ private[python] class PythonReducedWindowedDStream(
extends PythonDStream(parent, preduceFunc) {
super.persist(StorageLevel.MEMORY_ONLY)
- override val mustCheckpoint = true
- val invReduceFunc = new TransformFunction(pinvReduceFunc)
+ override val mustCheckpoint: Boolean = true
+
+ val invReduceFunc: TransformFunction = new TransformFunction(pinvReduceFunc)
def windowDuration: Duration = _windowDuration
+
override def slideDuration: Duration = _slideDuration
+
override def parentRememberDuration: Duration = rememberDuration + windowDuration
override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index b874f561c12eb..795c5aa6d585b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -104,7 +104,7 @@ abstract class DStream[T: ClassTag] (
private[streaming] def parentRememberDuration = rememberDuration
/** Return the StreamingContext associated with this DStream */
- def context = ssc
+ def context: StreamingContext = ssc
/* Set the creation call site */
private[streaming] val creationSite = DStream.getCreationSite()
@@ -619,14 +619,16 @@ abstract class DStream[T: ClassTag] (
* operator, so this DStream will be registered as an output stream and there materialized.
*/
def print(num: Int) {
- def foreachFunc = (rdd: RDD[T], time: Time) => {
- val firstNum = rdd.take(num + 1)
- println ("-------------------------------------------")
- println ("Time: " + time)
- println ("-------------------------------------------")
- firstNum.take(num).foreach(println)
- if (firstNum.size > num) println("...")
- println()
+ def foreachFunc: (RDD[T], Time) => Unit = {
+ (rdd: RDD[T], time: Time) => {
+ val firstNum = rdd.take(num + 1)
+ println("-------------------------------------------")
+ println("Time: " + time)
+ println("-------------------------------------------")
+ firstNum.take(num).foreach(println)
+ if (firstNum.size > num) println("...")
+ println()
+ }
}
new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala
index 0dc72790fbdbd..39fd21342813e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala
@@ -114,7 +114,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
}
}
- override def toString() = {
+ override def toString: String = {
"[\n" + currentCheckpointFiles.size + " checkpoint files \n" +
currentCheckpointFiles.mkString("\n") + "\n]"
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
index 22de8c02e63c8..66d519171fd76 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
@@ -298,7 +298,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]](
private[streaming]
class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) {
- def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]]
+ private def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]]
override def update(time: Time) {
hadoopFiles.clear()
@@ -320,7 +320,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]](
}
}
- override def toString() = {
+ override def toString: String = {
"[\n" + hadoopFiles.size + " file sets\n" +
hadoopFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n") + "\n]"
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala
index c81534ae584ea..fcd5216f101af 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala
@@ -27,7 +27,7 @@ class FilteredDStream[T: ClassTag](
filterFunc: T => Boolean
) extends DStream[T](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala
index 658623455498c..9d09a3baf37ca 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala
@@ -28,7 +28,7 @@ class FlatMapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag](
flatMapValueFunc: V => TraversableOnce[U]
) extends DStream[(K, U)](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala
index c7bb2833eabb8..475ea2d2d4f38 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala
@@ -27,7 +27,7 @@ class FlatMappedDStream[T: ClassTag, U: ClassTag](
flatMapFunc: T => Traversable[U]
) extends DStream[U](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
index 1361c30395b57..685a32e1d280d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
@@ -28,7 +28,7 @@ class ForEachDStream[T: ClassTag] (
foreachFunc: (RDD[T], Time) => Unit
) extends DStream[Unit](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala
index a9bb51f054048..dbb295fe54f71 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala
@@ -25,7 +25,7 @@ private[streaming]
class GlommedDStream[T: ClassTag](parent: DStream[T])
extends DStream[Array[T]](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
index aa1993f0580a8..e652702e213ef 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
@@ -61,7 +61,7 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext)
}
}
- override def dependencies = List()
+ override def dependencies: List[DStream[_]] = List()
override def slideDuration: Duration = {
if (ssc == null) throw new Exception("ssc is null")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala
index 3d8ee29df1e82..5994bc1e23f2b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala
@@ -28,7 +28,7 @@ class MapPartitionedDStream[T: ClassTag, U: ClassTag](
preservePartitioning: Boolean
) extends DStream[U](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala
index 7aea1f945d9db..954d2eb4a7b00 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala
@@ -28,7 +28,7 @@ class MapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag](
mapValueFunc: V => U
) extends DStream[(K, U)](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala
index 02704a8d1c2e0..fa14b2e897c3e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala
@@ -27,7 +27,7 @@ class MappedDStream[T: ClassTag, U: ClassTag] (
mapFunc: T => U
) extends DStream[U](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
index c0a5af0b65cc3..1385ccbf56ee5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
@@ -52,7 +52,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag](
// Reduce each batch of data using reduceByKey which will be further reduced by window
// by ReducedWindowedDStream
- val reducedStream = parent.reduceByKey(reduceFunc, partitioner)
+ private val reducedStream = parent.reduceByKey(reduceFunc, partitioner)
// Persist RDDs to memory by default as these RDDs are going to be reused.
super.persist(StorageLevel.MEMORY_ONLY_SER)
@@ -60,7 +60,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag](
def windowDuration: Duration = _windowDuration
- override def dependencies = List(reducedStream)
+ override def dependencies: List[DStream[_]] = List(reducedStream)
override def slideDuration: Duration = _slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
index 880a89bc36895..7757ccac09a58 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
@@ -33,7 +33,7 @@ class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag](
mapSideCombine: Boolean = true
) extends DStream[(K,C)] (parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
index ebb04dd35b9a2..de8718d0a80fe 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
@@ -36,7 +36,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
super.persist(StorageLevel.MEMORY_ONLY_SER)
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
index 71b61856e23c0..5d46ca0715ffd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
@@ -32,7 +32,7 @@ class TransformedDStream[U: ClassTag] (
require(parents.map(_.slideDuration).distinct.size == 1,
"Some of the DStreams have different slide durations")
- override def dependencies = parents.toList
+ override def dependencies: List[DStream[_]] = parents.toList
override def slideDuration: Duration = parents.head.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
index abbc40befa95b..9405dbaa12329 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
@@ -33,17 +33,17 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]])
require(parents.map(_.slideDuration).distinct.size == 1,
"Some of the DStreams have different slide durations")
- override def dependencies = parents.toList
+ override def dependencies: List[DStream[_]] = parents.toList
override def slideDuration: Duration = parents.head.slideDuration
override def compute(validTime: Time): Option[RDD[T]] = {
val rdds = new ArrayBuffer[RDD[T]]()
- parents.map(_.getOrCompute(validTime)).foreach(_ match {
+ parents.map(_.getOrCompute(validTime)).foreach {
case Some(rdd) => rdds += rdd
case None => throw new Exception("Could not generate RDD from a parent for unifying at time "
+ validTime)
- })
+ }
if (rdds.size > 0) {
Some(new UnionRDD(ssc.sc, rdds))
} else {
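
In the UnionDStream hunk above, `foreach(_ match { ... })` becomes a block of case clauses passed to `foreach` directly. The two forms are equivalent, because a block of case clauses is itself an anonymous pattern-matching function; dropping the `_ match` wrapper is just the more idiomatic spelling. A small self-contained sketch with hypothetical values, not taken from the patch:

    object ForeachMatchSketch {
      def main(args: Array[String]): Unit = {
        val computed: Seq[Option[Int]] = Seq(Some(1), None, Some(3))
        // The case block is an anonymous function, so it can be handed to
        // foreach without wrapping it in "_ match { ... }".
        computed.foreach {
          case Some(n) => println(s"got $n")
          case None    => println("nothing produced for this parent")
        }
      }
    }
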
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
index 775b6bfd065c0..899865a906c27 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
@@ -46,7 +46,7 @@ class WindowedDStream[T: ClassTag](
def windowDuration: Duration = _windowDuration
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = _slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
index dd1e96334952f..93caa4ba35c7f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
@@ -117,8 +117,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
override def getPreferredLocations(split: Partition): Seq[String] = {
val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition]
val blockLocations = getBlockIdLocations().get(partition.blockId)
- def segmentLocations = HdfsUtils.getFileSegmentLocations(
- partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig)
- blockLocations.getOrElse(segmentLocations)
+ blockLocations.getOrElse(
+ HdfsUtils.getFileSegmentLocations(
+ partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig))
}
}
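
The getPreferredLocations rewrite above drops the local `def segmentLocations` and passes the HdfsUtils call straight to `getOrElse`. Since `Option.getOrElse` takes its default as a by-name parameter, the file-segment lookup still runs only when no block location is cached, so the behaviour is unchanged. A minimal sketch of that by-name behaviour, with a hypothetical `expensiveLookup` standing in for the HDFS call:

    object GetOrElseSketch {
      // Hypothetical stand-in for the HDFS segment-location lookup.
      def expensiveLookup(): Seq[String] = {
        println("performing expensive lookup")
        Seq("host1", "host2")
      }

      def main(args: Array[String]): Unit = {
        val cached: Option[Seq[String]] = Some(Seq("cachedHost"))
        // The default is passed by name: nothing is printed for this call.
        println(cached.getOrElse(expensiveLookup()))

        // Only when the Option is empty does the default expression run.
        val missing: Option[Seq[String]] = None
        println(missing.getOrElse(expensiveLookup()))
      }
    }
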
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
index a7d63bd4f2dbf..cd309788a7717 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
@@ -17,6 +17,7 @@
package org.apache.spark.streaming.receiver
+import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.duration._
@@ -25,10 +26,10 @@ import scala.reflect.ClassTag
import akka.actor._
import akka.actor.SupervisorStrategy.{Escalate, Restart}
+
import org.apache.spark.{Logging, SparkEnv}
-import org.apache.spark.storage.StorageLevel
-import java.nio.ByteBuffer
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.storage.StorageLevel
/**
* :: DeveloperApi ::
@@ -149,13 +150,13 @@ private[streaming] class ActorReceiver[T: ClassTag](
class Supervisor extends Actor {
override val supervisorStrategy = receiverSupervisorStrategy
- val worker = context.actorOf(props, name)
+ private val worker = context.actorOf(props, name)
logInfo("Started receiver worker at:" + worker.path)
- val n: AtomicInteger = new AtomicInteger(0)
- val hiccups: AtomicInteger = new AtomicInteger(0)
+ private val n: AtomicInteger = new AtomicInteger(0)
+ private val hiccups: AtomicInteger = new AtomicInteger(0)
- def receive = {
+ override def receive: PartialFunction[Any, Unit] = {
case IteratorData(iterator) =>
logDebug("received iterator")
@@ -189,13 +190,12 @@ private[streaming] class ActorReceiver[T: ClassTag](
}
}
- def onStart() = {
+ def onStart(): Unit = {
supervisor
logInfo("Supervision tree for receivers initialized at:" + supervisor.path)
-
}
- def onStop() = {
+ def onStop(): Unit = {
supervisor ! PoisonPill
}
}
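
This file, like ReceiverSupervisorImpl, JobGenerator, JobScheduler and ReceiverTracker further down, annotates the actor's `receive` as `PartialFunction[Any, Unit]`. That is the type behind Akka's `Actor.Receive` alias, so the annotation only spells out what was already inferred. A minimal sketch, assuming the classic akka-actor dependency and a hypothetical EchoActor that is not part of the patch:

    import akka.actor.Actor

    // Hypothetical actor used only to illustrate the annotation.
    class EchoActor extends Actor {
      // Actor.Receive is an alias for PartialFunction[Any, Unit]; writing the
      // type out satisfies the explicit-result-type style rule.
      override def receive: PartialFunction[Any, Unit] = {
        case msg => println(s"received: $msg")
      }
    }
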
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
index ee5e639b26d91..42514d8b47dcf 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
@@ -120,7 +120,7 @@ private[streaming] class BlockGenerator(
* `BlockGeneratorListener.onAddData` callback will be called. All received data items
* will be periodically pushed into BlockManager.
*/
- def addDataWithCallback(data: Any, metadata: Any) = synchronized {
+ def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized {
waitToPush()
currentBuffer += data
listener.onAddData(data, metadata)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
index 5acf8a9a811ee..5b5a3fe648602 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
@@ -245,7 +245,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
* Get the unique identifier of the receiver input stream that this
* receiver is associated with.
*/
- def streamId = id
+ def streamId: Int = id
/*
* =================
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
index 1f0244c251eba..4943f29395d12 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
@@ -162,13 +162,13 @@ private[streaming] abstract class ReceiverSupervisor(
}
/** Check if receiver has been started */
- def isReceiverStarted() = {
+ def isReceiverStarted(): Boolean = {
logDebug("state = " + receiverState)
receiverState == Started
}
/** Check if receiver has been marked for stopping */
- def isReceiverStopped() = {
+ def isReceiverStopped(): Boolean = {
logDebug("state = " + receiverState)
receiverState == Stopped
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index 7d29ed88cfcb4..8f2f1fef76874 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Await
-import akka.actor.{Actor, Props}
+import akka.actor.{ActorRef, Actor, Props}
import akka.pattern.ask
import com.google.common.base.Throwables
import org.apache.hadoop.conf.Configuration
@@ -83,7 +83,7 @@ private[streaming] class ReceiverSupervisorImpl(
private val actor = env.actorSystem.actorOf(
Props(new Actor {
- override def receive() = {
+ override def receive: PartialFunction[Any, Unit] = {
case StopReceiver =>
logInfo("Received stop signal")
stop("Stopped by driver", None)
@@ -92,7 +92,7 @@ private[streaming] class ReceiverSupervisorImpl(
cleanupOldBlocks(threshTime)
}
- def ref = self
+ def ref: ActorRef = self
}), "Receiver-" + streamId + "-" + System.currentTimeMillis())
/** Unique block ids if one wants to add blocks directly */
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
index 7e0f6b2cdfc08..30cf87f5b7dd1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
@@ -36,5 +36,5 @@ class Job(val time: Time, func: () => _) {
id = "streaming job " + time + "." + number
}
- override def toString = id
+ override def toString: String = id
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 59488dfb0f8c6..4946806d2ee95 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -82,7 +82,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
if (eventActor != null) return // generator has already been started
eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
- def receive = {
+ override def receive: PartialFunction[Any, Unit] = {
case event: JobGeneratorEvent => processEvent(event)
}
}), "JobGenerator")
@@ -111,8 +111,8 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
val pollTime = 100
// To prevent a graceful stop from getting stuck permanently
- def hasTimedOut = {
- val timedOut = System.currentTimeMillis() - timeWhenStopStarted > stopTimeout
+ def hasTimedOut: Boolean = {
+ val timedOut = (System.currentTimeMillis() - timeWhenStopStarted) > stopTimeout
if (timedOut) {
logWarning("Timed out while stopping the job generator (timeout = " + stopTimeout + ")")
}
@@ -133,7 +133,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
logInfo("Stopped generation timer")
// Wait for the jobs to complete and checkpoints to be written
- def haveAllBatchesBeenProcessed = {
+ def haveAllBatchesBeenProcessed: Boolean = {
lastProcessedBatch != null && lastProcessedBatch.milliseconds == stopTime
}
logInfo("Waiting for jobs to be processed and checkpoints to be written")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 60bc099b27a4c..d6a93acbe711b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -56,7 +56,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
logDebug("Starting JobScheduler")
eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
- def receive = {
+ override def receive: PartialFunction[Any, Unit] = {
case event: JobSchedulerEvent => processEvent(event)
}
}), "JobScheduler")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
index 8c15a75b1b0e0..5b134877d0b2d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
@@ -28,8 +28,7 @@ private[streaming]
case class JobSet(
time: Time,
jobs: Seq[Job],
- receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty
- ) {
+ receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty) {
private val incompleteJobs = new HashSet[Job]()
private val submissionTime = System.currentTimeMillis() // when this jobset was submitted
@@ -48,17 +47,17 @@ case class JobSet(
if (hasCompleted) processingEndTime = System.currentTimeMillis()
}
- def hasStarted = processingStartTime > 0
+ def hasStarted: Boolean = processingStartTime > 0
- def hasCompleted = incompleteJobs.isEmpty
+ def hasCompleted: Boolean = incompleteJobs.isEmpty
// Time taken to process all the jobs from the time they started processing
// (i.e. not including the time they wait in the streaming scheduler queue)
- def processingDelay = processingEndTime - processingStartTime
+ def processingDelay: Long = processingEndTime - processingStartTime
// Time taken to process all the jobs from the time they were submitted
// (i.e. including the time they wait in the streaming scheduler queue)
- def totalDelay = {
+ def totalDelay: Long = {
processingEndTime - time.milliseconds
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index b36aeb341d25e..98900473138fe 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -72,7 +72,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
private var actor: ActorRef = null
/** Start the actor and receiver execution thread. */
- def start() = synchronized {
+ def start(): Unit = synchronized {
if (actor != null) {
throw new SparkException("ReceiverTracker already started")
}
@@ -86,7 +86,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
}
/** Stop the receiver execution thread. */
- def stop(graceful: Boolean) = synchronized {
+ def stop(graceful: Boolean): Unit = synchronized {
if (!receiverInputStreams.isEmpty && actor != null) {
// First, stop the receivers
if (!skipReceiverLaunch) receiverExecutor.stop(graceful)
@@ -201,7 +201,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
/** Actor to receive messages from the receivers. */
private class ReceiverTrackerActor extends Actor {
- def receive = {
+ override def receive: PartialFunction[Any, Unit] = {
case RegisterReceiver(streamId, typ, host, receiverActor) =>
registerReceiver(streamId, typ, host, receiverActor, sender)
sender ! true
@@ -244,16 +244,15 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
if (graceful) {
val pollTime = 100
- def done = { receiverInfo.isEmpty && !running }
logInfo("Waiting for receiver job to terminate gracefully")
- while(!done) {
+ while (receiverInfo.nonEmpty || running) {
Thread.sleep(pollTime)
}
logInfo("Waited for receiver job to terminate gracefully")
}
// Check if all the receivers have been deregistered or not
- if (!receiverInfo.isEmpty) {
+ if (receiverInfo.nonEmpty) {
logWarning("Not all of the receivers have deregistered, " + receiverInfo)
} else {
logInfo("All of the receivers have deregistered successfully")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
index 5ee53a5c5f561..e4bd067cacb77 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
@@ -17,9 +17,10 @@
package org.apache.spark.streaming.ui
+import scala.collection.mutable.{Queue, HashMap}
+
import org.apache.spark.streaming.{Time, StreamingContext}
import org.apache.spark.streaming.scheduler._
-import scala.collection.mutable.{Queue, HashMap}
import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
import org.apache.spark.streaming.scheduler.StreamingListenerBatchStarted
import org.apache.spark.streaming.scheduler.BatchInfo
@@ -59,11 +60,13 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext)
}
}
- override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) = synchronized {
- runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo
+ override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = {
+ synchronized {
+ runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo
+ }
}
- override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) = synchronized {
+ override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = synchronized {
runningBatchInfos(batchStarted.batchInfo.batchTime) = batchStarted.batchInfo
waitingBatchInfos.remove(batchStarted.batchInfo.batchTime)
@@ -72,19 +75,21 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext)
}
}
- override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) = synchronized {
- waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime)
- runningBatchInfos.remove(batchCompleted.batchInfo.batchTime)
- completedaBatchInfos.enqueue(batchCompleted.batchInfo)
- if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue()
- totalCompletedBatches += 1L
-
- batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) =>
- totalProcessedRecords += infos.map(_.numRecords).sum
+ override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
+ synchronized {
+ waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime)
+ runningBatchInfos.remove(batchCompleted.batchInfo.batchTime)
+ completedaBatchInfos.enqueue(batchCompleted.batchInfo)
+ if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue()
+ totalCompletedBatches += 1L
+
+ batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) =>
+ totalProcessedRecords += infos.map(_.numRecords).sum
+ }
}
}
- def numReceivers = synchronized {
+ def numReceivers: Int = synchronized {
ssc.graph.getReceiverInputStreams().size
}
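
The listener callbacks above gain an explicit `: Unit` result type, with the `synchronized` call either kept as the whole method body or wrapped in an outer pair of braces. The two shapes behave identically; the annotation just states that these side-effecting public methods return Unit. A small sketch of both forms, using a hypothetical counter in place of the listener's batch maps:

    object SynchronizedUnitSketch {
      private var total = 0L // hypothetical state, not the listener's maps

      // synchronized expression used as the entire method body.
      def addInline(n: Long): Unit = synchronized {
        total += n
      }

      // synchronized block wrapped in braces; behaviourally identical.
      def addWrapped(n: Long): Unit = {
        synchronized {
          total += n
        }
      }
    }
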
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala
index a73d6f3bf0661..4d968f8bfa7a8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala
@@ -18,9 +18,7 @@
package org.apache.spark.streaming.util
import org.apache.spark.SparkContext
-import org.apache.spark.SparkContext._
import org.apache.spark.util.collection.OpenHashMap
-import scala.collection.JavaConversions.mapAsScalaMap
private[streaming]
object RawTextHelper {
@@ -71,7 +69,7 @@ object RawTextHelper {
var count = 0
while(data.hasNext) {
- value = data.next
+ value = data.next()
if (value != null) {
count += 1
if (len == 0) {
@@ -108,9 +106,13 @@ object RawTextHelper {
}
}
- def add(v1: Long, v2: Long) = (v1 + v2)
+ def add(v1: Long, v2: Long): Long = {
+ v1 + v2
+ }
- def subtract(v1: Long, v2: Long) = (v1 - v2)
+ def subtract(v1: Long, v2: Long): Long = {
+ v1 - v2
+ }
- def max(v1: Long, v2: Long) = math.max(v1, v2)
+ def max(v1: Long, v2: Long): Long = math.max(v1, v2)
}