[compiler] apply scalafmt to all scala sources (hail-is#14129)
Run scalafmt on compiler scala sources.
Split from hail-is#14103.
Co-authored by @patrick-schultz.
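Note: the exact formatting command isn't recorded in this commit. Assuming the standard scalafmt tooling and the repository's .scalafmt.conf, a typical way to apply it would be to run the standalone CLI over the compiler sources, e.g. scalafmt hail/src/main/scala (formats files in place), or scalafmtAll in sbt-based projects using the sbt-scalafmt plugin.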
ehigham authored Jan 8, 2024
1 parent 5606d94 commit 422edf6
Showing 641 changed files with 58,696 additions and 29,702 deletions.
61 changes: 37 additions & 24 deletions hail/src/main/scala/is/hail/HailContext.scala
@@ -8,10 +8,7 @@ import is.hail.io.fs.FS
import is.hail.io.vcf._
import is.hail.types.virtual._
import is.hail.utils._
import org.apache.log4j.{ConsoleAppender, LogManager, PatternLayout, PropertyConfigurator}
import org.apache.spark._
import org.apache.spark.executor.InputMetrics
import org.apache.spark.rdd.RDD

import org.json4s.Extraction
import org.json4s.jackson.JsonMethods

@@ -20,6 +17,11 @@ import java.util.Properties
import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.log4j.{ConsoleAppender, LogManager, PatternLayout, PropertyConfigurator}
import org.apache.spark._
import org.apache.spark.executor.InputMetrics
import org.apache.spark.rdd.RDD

case class FilePartition(index: Int, file: String) extends Partition

object HailContext {
@@ -78,7 +80,8 @@ object HailContext {
versionString match {
// old-style version: 1.MAJOR.MINOR
// new-style version: MAJOR.MINOR.SECURITY (started in JRE 9)
// see: https://docs.oracle.com/javase/9/migrate/toc.htm#JSMIG-GUID-3A71ECEF-5FC5-46FE-9BA9-88CBFCE828CB
/* see:
* https://docs.oracle.com/javase/9/migrate/toc.htm#JSMIG-GUID-3A71ECEF-5FC5-46FE-9BA9-88CBFCE828CB */
case javaVersion("1", major, minor) =>
if (major.toInt < 8)
fatal(s"Hail requires Java 1.8, found $versionString")
@@ -90,38 +93,46 @@ object HailContext {
}
}

def getOrCreate(backend: Backend,
branchingFactor: Int = 50,
optimizerIterations: Int = 3): HailContext = {
def getOrCreate(backend: Backend, branchingFactor: Int = 50, optimizerIterations: Int = 3)
: HailContext = {
if (theContext == null)
return HailContext(backend, branchingFactor, optimizerIterations)

if (theContext.branchingFactor != branchingFactor)
warn(s"Requested branchingFactor $branchingFactor, but already initialized to ${ theContext.branchingFactor }. Ignoring requested setting.")
warn(
s"Requested branchingFactor $branchingFactor, but already initialized to ${theContext.branchingFactor}. Ignoring requested setting."
)

if (theContext.optimizerIterations != optimizerIterations)
warn(s"Requested optimizerIterations $optimizerIterations, but already initialized to ${ theContext.optimizerIterations }. Ignoring requested setting.")
warn(
s"Requested optimizerIterations $optimizerIterations, but already initialized to ${theContext.optimizerIterations}. Ignoring requested setting."
)

theContext
}

def apply(backend: Backend,
branchingFactor: Int = 50,
optimizerIterations: Int = 3): HailContext = synchronized {
def apply(backend: Backend, branchingFactor: Int = 50, optimizerIterations: Int = 3)
: HailContext = synchronized {
require(theContext == null)
checkJavaVersion()

{
import breeze.linalg._
import breeze.linalg.operators.{BinaryRegistry, OpMulMatrix}

implicitly[BinaryRegistry[DenseMatrix[Double], Vector[Double], OpMulMatrix.type, DenseVector[Double]]].register(
DenseMatrix.implOpMulMatrix_DMD_DVD_eq_DVD)
implicitly[BinaryRegistry[
DenseMatrix[Double],
Vector[Double],
OpMulMatrix.type,
DenseVector[Double],
]].register(
DenseMatrix.implOpMulMatrix_DMD_DVD_eq_DVD
)
}

theContext = new HailContext(backend, branchingFactor, optimizerIterations)

info(s"Running Hail version ${ theContext.version }")
info(s"Running Hail version ${theContext.version}")

theContext
}
@@ -138,7 +149,8 @@ object HailContext {
path: String,
partFiles: IndexedSeq[String],
read: (Int, InputStream, InputMetrics) => Iterator[T],
optPartitioner: Option[Partitioner] = None): RDD[T] = {
optPartitioner: Option[Partitioner] = None,
): RDD[T] = {
val nPartitions = partFiles.length

val fsBc = fs.broadcast
@@ -159,10 +171,11 @@
}
}

class HailContext private(
class HailContext private (
var backend: Backend,
val branchingFactor: Int,
val optimizerIterations: Int) {
val optimizerIterations: Int,
) {
def stop(): Unit = HailContext.stop()

def sparkBackend(op: String): SparkBackend = backend.asSpark(op)
@@ -175,7 +188,7 @@ class HailContext private(
fs: FS,
regex: String,
files: Seq[String],
maxLines: Int
maxLines: Int,
): Map[String, Array[WithContext[String]]] = {
val regexp = regex.r
SparkBackend.sparkContext("fileAndLineCounts").textFilesLines(fs.globAll(files).map(_.getPath))
@@ -186,7 +199,7 @@

def grepPrint(fs: FS, regex: String, files: Seq[String], maxLines: Int) {
fileAndLineCounts(fs, regex, files, maxLines).foreach { case (file, lines) =>
info(s"$file: ${ lines.length } ${ plural(lines.length, "match", "matches") }:")
info(s"$file: ${lines.length} ${plural(lines.length, "match", "matches")}:")
lines.map(_.value).foreach { line =>
val (screen, logged) = line.truncatable().strings
log.info("\t" + logged)
@@ -195,12 +208,12 @@ class HailContext private(
}
}

def grepReturn(fs: FS, regex: String, files: Seq[String], maxLines: Int): Array[(String, Array[String])] =
def grepReturn(fs: FS, regex: String, files: Seq[String], maxLines: Int)
: Array[(String, Array[String])] =
fileAndLineCounts(fs: FS, regex, files, maxLines).mapValues(_.map(_.value)).toArray

def parseVCFMetadata(fs: FS, file: String): Map[String, Map[String, Map[String, String]]] = {
def parseVCFMetadata(fs: FS, file: String): Map[String, Map[String, Map[String, String]]] =
LoadVCF.parseHeaderMetadata(fs, Set.empty, TFloat64, file)
}

def pyParseVCFMetadataJSON(fs: FS, file: String): String = {
val metadata = LoadVCF.parseHeaderMetadata(fs, Set.empty, TFloat64, file)
6 changes: 4 additions & 2 deletions hail/src/main/scala/is/hail/HailFeatureFlags.scala
@@ -2,6 +2,7 @@ package is.hail

import is.hail.backend.ExecutionCache
import is.hail.utils._

import org.json4s.JsonAST.{JArray, JObject, JString}

import scala.collection.mutable
@@ -75,9 +76,10 @@ class HailFeatureFlags private (
def toJSONEnv: JArray =
JArray(flags.filter { case (_, v) =>
v != null
}.map{ case (name, v) =>
}.map { case (name, v) =>
JObject(
"name" -> JString(HailFeatureFlags.defaults(name)._1),
"value" -> JString(v))
"value" -> JString(v),
)
}.toList)
}
14 changes: 11 additions & 3 deletions hail/src/main/scala/is/hail/annotations/Annotation.scala
@@ -2,6 +2,7 @@ package is.hail.annotations

import is.hail.types.virtual._
import is.hail.utils._

import org.apache.spark.sql.Row

object Annotation {
@@ -33,14 +34,21 @@ object Annotation {

case t: TInterval =>
val i = a.asInstanceOf[Interval]
i.copy(start = Annotation.copy(t.pointType, i.start), end = Annotation.copy(t.pointType, i.end))
i.copy(
start = Annotation.copy(t.pointType, i.start),
end = Annotation.copy(t.pointType, i.end),
)

case t: TNDArray =>
val nd = a.asInstanceOf[NDArray]
val rme = nd.getRowMajorElements()
SafeNDArray(nd.shape, Array.tabulate(rme.length)(i => Annotation.copy(t.elementType, rme(i))).toFastSeq)
SafeNDArray(
nd.shape,
Array.tabulate(rme.length)(i => Annotation.copy(t.elementType, rme(i))).toFastSeq,
)

case TInt32 | TInt64 | TFloat32 | TFloat64 | TBoolean | TString | TCall | _: TLocus | TBinary => a
case TInt32 | TInt64 | TFloat32 | TFloat64 | TBoolean | TString | TCall | _: TLocus | TBinary =>
a
}
}
}
53 changes: 33 additions & 20 deletions hail/src/main/scala/is/hail/annotations/BroadcastValue.scala
@@ -1,19 +1,22 @@
package is.hail.annotations

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}
import is.hail.asm4s.HailClassLoader
import is.hail.backend.{BroadcastValue, ExecuteContext}
import is.hail.expr.ir.EncodedLiteral
import is.hail.io.{BufferSpec, Decoder, TypedCodecSpec}
import is.hail.types.physical.{PArray, PStruct, PType}
import is.hail.types.virtual.{TBaseStruct, TStruct}
import is.hail.io.{BufferSpec, Decoder, TypedCodecSpec}
import is.hail.utils.{ArrayOfByteArrayOutputStream, formatSpace, log}
import is.hail.utils.{formatSpace, log, ArrayOfByteArrayOutputStream}
import is.hail.utils.prettyPrint.ArrayOfByteArrayInputStream

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}

import org.apache.spark.sql.Row

case class SerializableRegionValue(
encodedValue: Array[Array[Byte]], t: PType,
makeDecoder: (InputStream, HailClassLoader) => Decoder
encodedValue: Array[Array[Byte]],
t: PType,
makeDecoder: (InputStream, HailClassLoader) => Decoder,
) {
def readRegionValue(r: Region, theHailClassLoader: HailClassLoader): Long = {
val dec = makeDecoder(new ArrayOfByteArrayInputStream(encodedValue), theHailClassLoader)
@@ -69,7 +72,9 @@ trait BroadcastRegionValue {
if (broadcasted == null) {
val arrays = encodeToByteArrays(theHailClassLoader)
val totalSize = arrays.map(_.length).sum
log.info(s"BroadcastRegionValue.broadcast: broadcasting ${ arrays.length } byte arrays of total size $totalSize (${ formatSpace(totalSize) }")
log.info(
s"BroadcastRegionValue.broadcast: broadcasting ${arrays.length} byte arrays of total size $totalSize (${formatSpace(totalSize)}"
)
val srv = SerializableRegionValue(arrays, decodedPType, makeDec)
broadcasted = ctx.backend.broadcast(srv)
}
@@ -83,17 +88,16 @@ trait BroadcastRegionValue {
def safeJavaValue: Any

override def equals(obj: Any): Boolean = obj match {
case b: BroadcastRegionValue => t == b.t && (ctx eq b.ctx) && t.unsafeOrdering(ctx.stateManager).compare(value, b.value) == 0
case b: BroadcastRegionValue =>
t == b.t && (ctx eq b.ctx) && t.unsafeOrdering(ctx.stateManager).compare(value, b.value) == 0
case _ => false
}

override def hashCode(): Int = javaValue.hashCode()
}

case class BroadcastRow(ctx: ExecuteContext,
value: RegionValue,
t: PStruct
) extends BroadcastRegionValue {
case class BroadcastRow(ctx: ExecuteContext, value: RegionValue, t: PStruct)
extends BroadcastRegionValue {

def javaValue: UnsafeRow = UnsafeRow.readBaseStruct(t, value.region, value.offset)

@@ -104,20 +108,24 @@ case class BroadcastRow(ctx: ExecuteContext,
if (t == newT)
return this

BroadcastRow(ctx,
RegionValue(value.region, newT.copyFromAddress(ctx.stateManager, value.region, t, value.offset, deepCopy = false)),
newT)
BroadcastRow(
ctx,
RegionValue(
value.region,
newT.copyFromAddress(ctx.stateManager, value.region, t, value.offset, deepCopy = false),
),
newT,
)
}

def toEncodedLiteral(theHailClassLoader: HailClassLoader): EncodedLiteral = {
def toEncodedLiteral(theHailClassLoader: HailClassLoader): EncodedLiteral =
EncodedLiteral(encoding, encodeToByteArrays(theHailClassLoader))
}
}

case class BroadcastIndexedSeq(
ctx: ExecuteContext,
value: RegionValue,
t: PArray
t: PArray,
) extends BroadcastRegionValue {

def safeJavaValue: IndexedSeq[Row] = SafeRow.read(t, value).asInstanceOf[IndexedSeq[Row]]
@@ -129,8 +137,13 @@ case class BroadcastIndexedSeq(
if (t == newT)
return this

BroadcastIndexedSeq(ctx,
RegionValue(value.region, newT.copyFromAddress(ctx.stateManager, value.region, t, value.offset, deepCopy = false)),
newT)
BroadcastIndexedSeq(
ctx,
RegionValue(
value.region,
newT.copyFromAddress(ctx.stateManager, value.region, t, value.offset, deepCopy = false),
),
newT,
)
}
}
