[qob] Read results files 20x faster by not using Scala parallel collections #12854

Merged · 2 commits · Apr 7, 2023
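Context for the diff below: Scala parallel collections (`.par`) schedule work on a fork-join pool sized to the number of CPU cores, which throttles IO-bound work like fetching thousands of small result files; a cached thread pool gives every blocked read its own thread so the reads overlap. A minimal standalone sketch of the difference, assuming Scala 2.12 (where `.par` is in the standard library) — the `fetch` stand-in and the timings are illustrative, not code from this PR:

import java.util.concurrent.{Callable, Executors}
import scala.collection.JavaConverters._

object ParVsExecutor {
  // Stand-in for an IO-bound GCS read: blocks for 100 ms, returns its index.
  def fetch(i: Int): Int = { Thread.sleep(100); i }

  def main(args: Array[String]): Unit = {
    val n = 64

    // Parallel collections: parallelism defaults to the CPU count, so on an
    // 8-core machine these 64 blocking calls run in roughly 8 waves (~800 ms).
    var t0 = System.nanoTime()
    (0 until n).par.map(fetch)
    println(s".par took ${(System.nanoTime() - t0) / 1e9} s")

    // Cached thread pool: one thread per blocked task, so all 64 calls
    // overlap and the whole batch finishes in roughly one wave (~100 ms).
    val executor = Executors.newCachedThreadPool()
    t0 = System.nanoTime()
    executor.invokeAll(
      (0 until n).map(i => new Callable[Int] { def call(): Int = fetch(i) }).asJava
    ).asScala.map(_.get)
    executor.shutdown()
    println(s"executor took ${(System.nanoTime() - t0) / 1e9} s")
  }
}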
107 changes: 54 additions & 53 deletions hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
@@ -37,10 +37,8 @@ import org.newsclub.net.unix.{AFUNIXServerSocket, AFUNIXSocketAddress}
 
 import scala.annotation.switch
 import scala.reflect.ClassTag
-import scala.{concurrent => scalaConcurrent}
 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.collection.parallel.ExecutionContextTaskSupport
 
 
 class ServiceBackendContext(
   @transient val sessionID: String,
@@ -69,9 +67,8 @@ class ServiceBackend(
   import ServiceBackend.log
 
   private[this] var stageCount = 0
-  private[this] implicit val ec = scalaConcurrent.ExecutionContext.fromExecutorService(
-    Executors.newCachedThreadPool())
-  private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 100
+  private[this] val executor = Executors.newCachedThreadPool()
+  private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000
   private[this] val availableGCSConnections = new Semaphore(MAX_AVAILABLE_GCS_CONNECTIONS, true)
 
   override def shouldCacheQueryInfo: Boolean = false
@@ -117,35 +114,39 @@
     log.info(s"parallelizeAndComputeWithIndex: $token: nPartitions $n")
     log.info(s"parallelizeAndComputeWithIndex: $token: writing f and contexts")
 
-    val uploadFunction = scalaConcurrent.Future {
-      retryTransientErrors {
-        write(s"$root/f") { fos =>
-          using(new ObjectOutputStream(fos)) { oos => oos.writeObject(f) }
-        }
-      }
-    }
+    val uploadFunction = executor.submit(new Callable[Unit] {
+      def call(): Unit = {
+        retryTransientErrors {
+          write(s"$root/f") { fos =>
+            using(new ObjectOutputStream(fos)) { oos => oos.writeObject(f) }
+          }
+        }
+      }
+    })
 
-    val uploadContexts = scalaConcurrent.Future {
-      retryTransientErrors {
-        write(s"$root/contexts") { os =>
-          var o = 12L * n
-          var i = 0
-          while (i < n) {
-            val len = collection(i).length
-            os.writeLong(o)
-            os.writeInt(len)
-            i += 1
-            o += len
-          }
-          collection.foreach { context =>
-            os.write(context)
-          }
-        }
-      }
-    }
+    val uploadContexts = executor.submit(new Callable[Unit] {
+      def call(): Unit = {
+        retryTransientErrors {
+          write(s"$root/contexts") { os =>
+            var o = 12L * n
+            var i = 0
+            while (i < n) {
+              val len = collection(i).length
+              os.writeLong(o)
+              os.writeInt(len)
+              i += 1
+              o += len
+            }
+            collection.foreach { context =>
+              os.write(context)
+            }
+          }
+        }
+      }
+    })
 
-    scalaConcurrent.Await.result(uploadFunction, scalaConcurrent.duration.Duration.Inf)
-    scalaConcurrent.Await.result(uploadContexts, scalaConcurrent.duration.Duration.Inf)
+    uploadFunction.get()
+    uploadContexts.get()
 
     val jobs = new Array[JObject](n)
     var i = 0
@@ -212,35 +213,35 @@
 
     log.info(s"parallelizeAndComputeWithIndex: $token: reading results")
 
-    def resultOrHailException(is: DataInputStream): Array[Byte] = {
-      val success = is.readBoolean()
-      if (success) {
-        IOUtils.toByteArray(is)
-      } else {
-        val shortMessage = readString(is)
-        val expandedMessage = readString(is)
-        val errorId = is.readInt()
-        throw new HailWorkerException(shortMessage, expandedMessage, errorId)
-      }
-    }
-
-
-    val results = Array.range(0, n).par.map { i =>
-      availableGCSConnections.acquire()
-      try {
-        val bytes = retryTransientErrors {
-          using(open(s"$root/result.$i")) { is =>
-            resultOrHailException(new DataInputStream(is))
-          }
-        }
-        log.info(s"result $i complete - ${bytes.length} bytes")
-        bytes
-      } finally {
-        availableGCSConnections.release()
-      }
-    }
-
-    log.info(s"all results complete")
+    val startTime = System.nanoTime()
+
+    val results = executor.invokeAll(IndexedSeq.range(0, n).map { i =>
+      new Callable[Array[Byte]]() {
+        def call(): Array[Byte] = {
+          availableGCSConnections.acquire()
+          try {
+            val bytes = fs.readNoCompression(s"$root/result.$i")
+            if (bytes(0) != 0) {
+              bytes.slice(1, bytes.length)
+            } else {
+              val errorInformationBytes = bytes.slice(1, bytes.length)
+              val is = new DataInputStream(new ByteArrayInputStream(errorInformationBytes))
+              val shortMessage = readString(is)
+              val expandedMessage = readString(is)
+              val errorId = is.readInt()
+              throw new HailWorkerException(shortMessage, expandedMessage, errorId)
+            }
+          } finally {
+            availableGCSConnections.release()
+          }
+        }
+      }
+    }.asJava).asScala.map(_.get).toArray
+
+    val resultsReadingSeconds = (System.nanoTime() - startTime) / 1000000000.0
+    val rate = results.length / resultsReadingSeconds
+    val byterate = results.map(_.length).sum / resultsReadingSeconds / 1024 / 1024
+    log.info(s"all results read. $resultsReadingSeconds s. $rate result/s. $byterate MiB/s.")
     results.toArray[Array[Byte]]
   }
 
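The `12L * n` in the context-upload loop above reflects the contexts-file layout: a header of n fixed-width entries — an 8-byte offset plus a 4-byte length, 12 bytes per partition — followed by the concatenated context blobs, so the first blob begins at byte 12 * n. A hedged reader-side sketch of that layout (not code from this PR; `readContext` is a hypothetical helper):

import java.io.{ByteArrayInputStream, DataInputStream}

// Hypothetical helper: extract partition i's context from the raw contexts file.
def readContext(contexts: Array[Byte], i: Int): Array[Byte] = {
  val header = new DataInputStream(new ByteArrayInputStream(contexts))
  header.skipBytes(12 * i)        // each header entry is 8 + 4 = 12 bytes
  val offset = header.readLong()  // where this partition's blob begins
  val length = header.readInt()   // how many bytes it occupies
  contexts.slice(offset.toInt, offset.toInt + length) // assumes the file fits in an Int range
}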
4 changes: 4 additions & 0 deletions hail/src/main/scala/is/hail/io/fs/FS.scala
@@ -319,6 +319,10 @@ trait FS extends Serializable {
   final def openNoCompression(filename: String): SeekableDataInputStream = openNoCompression(filename, false)
   def openNoCompression(filename: String, _debug: Boolean): SeekableDataInputStream
 
+  def readNoCompression(filename: String): Array[Byte] = retryTransientErrors {
+    IOUtils.toByteArray(openNoCompression(filename))
+  }
+
   def createNoCompression(filename: String): PositionedDataOutputStream
 
   def mkDir(dirname: String): Unit = ()
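The default `readNoCompression` above simply drains `openNoCompression` through `IOUtils.toByteArray`, so every `FS` implementation gets the method for free while subclasses can substitute a one-shot read. A hedged usage sketch (the path is illustrative):

// `fs` is any is.hail.io.fs.FS; the GoogleStorageFS override below turns this
// into a single readAllBytes call instead of a buffered stream copy.
val bytes: Array[Byte] = fs.readNoCompression("gs://my-bucket/some/result.0")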
7 changes: 7 additions & 0 deletions hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala
@@ -5,6 +5,7 @@ import java.io.{ByteArrayInputStream, FileNotFoundException, IOException}
 import java.net.URI
 import java.nio.ByteBuffer
 import java.nio.file.FileSystems
+import java.util.concurrent._
 import org.apache.log4j.Logger
 import com.google.auth.oauth2.ServiceAccountCredentials
 import com.google.cloud.{ReadChannel, WriteChannel}
@@ -18,6 +19,7 @@ import is.hail.utils.fatal
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.{concurrent => scalaConcurrent}
 import scala.reflect.ClassTag
 
 object GoogleStorageFS {
@@ -252,6 +254,11 @@ class GoogleStorageFS(
     new WrappedSeekableDataInputStream(is)
   }
 
+  override def readNoCompression(filename: String): Array[Byte] = retryTransientErrors {
+    val (bucket, path) = getBucketPath(filename)
+    storage.readAllBytes(bucket, path)
+  }
+
   def createNoCompression(filename: String): PositionedDataOutputStream = retryTransientErrors {
     log.info(f"createNoCompression: ${filename}")
     val (bucket, path) = getBucketPath(filename)
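For reference, `Storage.readAllBytes` is the google-cloud-storage client call the override leans on; it fetches an entire object in one request rather than paging through a `ReadChannel`. A minimal standalone sketch (bucket and object names are illustrative):

import com.google.cloud.storage.StorageOptions

// Default credentials; one RPC returns the whole object as a byte array.
val storage = StorageOptions.getDefaultInstance.getService
val bytes: Array[Byte] = storage.readAllBytes("my-bucket", "results/result.0")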
2 changes: 2 additions & 0 deletions hail/src/main/scala/is/hail/io/fs/RouterFS.scala
@@ -15,6 +15,8 @@ class RouterFS(schemes: Map[String, FS], default: String) extends FS {
 
   def createNoCompression(filename: String): PositionedDataOutputStream = lookupFS(filename).createNoCompression(filename)
 
+  override def readNoCompression(filename: String): Array[Byte] = lookupFS(filename).readNoCompression(filename)
+
   override def mkDir(dirname: String): Unit = lookupFS(dirname).mkDir(dirname)
 
   def delete(filename: String, recursive: Boolean) = lookupFS(filename).delete(filename, recursive)