diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
index d891fd4c5b7..45e06127c7c 100644
--- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
+++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
@@ -37,10 +37,8 @@ import org.newsclub.net.unix.{AFUNIXServerSocket, AFUNIXSocketAddress}
 
 import scala.annotation.switch
 import scala.reflect.ClassTag
-import scala.{concurrent => scalaConcurrent}
+import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.collection.parallel.ExecutionContextTaskSupport
-
 
 class ServiceBackendContext(
   @transient val sessionID: String,
@@ -69,9 +67,8 @@ class ServiceBackend(
   import ServiceBackend.log
 
   private[this] var stageCount = 0
-  private[this] implicit val ec = scalaConcurrent.ExecutionContext.fromExecutorService(
-    Executors.newCachedThreadPool())
-  private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 100
+  private[this] val executor = Executors.newCachedThreadPool()
+  private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000
   private[this] val availableGCSConnections = new Semaphore(MAX_AVAILABLE_GCS_CONNECTIONS, true)
 
   override def shouldCacheQueryInfo: Boolean = false
@@ -117,35 +114,39 @@ class ServiceBackend(
     log.info(s"parallelizeAndComputeWithIndex: $token: nPartitions $n")
     log.info(s"parallelizeAndComputeWithIndex: $token: writing f and contexts")
 
-    val uploadFunction = scalaConcurrent.Future {
-      retryTransientErrors {
-        write(s"$root/f") { fos =>
-          using(new ObjectOutputStream(fos)) { oos => oos.writeObject(f) }
+    val uploadFunction = executor.submit(new Callable[Unit] {
+      def call(): Unit = {
+        retryTransientErrors {
+          write(s"$root/f") { fos =>
+            using(new ObjectOutputStream(fos)) { oos => oos.writeObject(f) }
+          }
         }
       }
-    }
+    })
 
-    val uploadContexts = scalaConcurrent.Future {
-      retryTransientErrors {
-        write(s"$root/contexts") { os =>
-          var o = 12L * n
-          var i = 0
-          while (i < n) {
-            val len = collection(i).length
-            os.writeLong(o)
-            os.writeInt(len)
-            i += 1
-            o += len
-          }
-          collection.foreach { context =>
-            os.write(context)
+    val uploadContexts = executor.submit(new Callable[Unit] {
+      def call(): Unit = {
+        retryTransientErrors {
+          write(s"$root/contexts") { os =>
+            var o = 12L * n
+            var i = 0
+            while (i < n) {
+              val len = collection(i).length
+              os.writeLong(o)
+              os.writeInt(len)
+              i += 1
+              o += len
+            }
+            collection.foreach { context =>
+              os.write(context)
+            }
           }
         }
       }
-    }
+    })
 
-    scalaConcurrent.Await.result(uploadFunction, scalaConcurrent.duration.Duration.Inf)
-    scalaConcurrent.Await.result(uploadContexts, scalaConcurrent.duration.Duration.Inf)
+    uploadFunction.get()
+    uploadContexts.get()
 
     val jobs = new Array[JObject](n)
     var i = 0
@@ -212,35 +213,39 @@ class ServiceBackend(
 
     log.info(s"parallelizeAndComputeWithIndex: $token: reading results")
 
-    def resultOrHailException(is: DataInputStream): Array[Byte] = {
-      val success = is.readBoolean()
-      if (success) {
-        IOUtils.toByteArray(is)
-      } else {
-        val shortMessage = readString(is)
-        val expandedMessage = readString(is)
-        val errorId = is.readInt()
-        throw new HailWorkerException(shortMessage, expandedMessage, errorId)
-      }
-    }
-
-
-    val results = Array.range(0, n).par.map { i =>
-      availableGCSConnections.acquire()
-      try {
-        val bytes = retryTransientErrors {
-          using(open(s"$root/result.$i")) { is =>
-            resultOrHailException(new DataInputStream(is))
+    val startTime = System.nanoTime()
+
+    val results = try {
+      executor.invokeAll(IndexedSeq.range(0, n).map { i =>
+        new Callable[Array[Byte]]() {
+          def call(): Array[Byte] = {
+            availableGCSConnections.acquire()
+            try {
+              val bytes = fs.readNoCompression(s"$root/result.$i")
+              if (bytes(0) != 0) {
+                bytes.slice(1, bytes.length)
+              } else {
+                val errorInformationBytes = bytes.slice(1, bytes.length)
+                val is = new DataInputStream(new ByteArrayInputStream(errorInformationBytes))
+                val shortMessage = readString(is)
+                val expandedMessage = readString(is)
+                val errorId = is.readInt()
+                throw new HailWorkerException(shortMessage, expandedMessage, errorId)
+              }
+            } finally {
+              availableGCSConnections.release()
+            }
           }
         }
-        log.info(s"result $i complete - ${bytes.length} bytes")
-        bytes
-      } finally {
-        availableGCSConnections.release()
-      }
+      }.asJava).asScala.map(_.get).toArray
+    } catch {
+      case exc: ExecutionException if exc.getCause() != null => throw exc.getCause()
     }
-    log.info(s"all results complete")
+    val resultsReadingSeconds = (System.nanoTime() - startTime) / 1000000000.0
+    val rate = results.length / resultsReadingSeconds
+    val byterate = results.map(_.length).sum / resultsReadingSeconds / 1024 / 1024
+    log.info(s"all results read. $resultsReadingSeconds s. $rate result/s. $byterate MiB/s.")
 
     results.toArray[Array[Byte]]
   }
 
diff --git a/hail/src/main/scala/is/hail/io/fs/FS.scala b/hail/src/main/scala/is/hail/io/fs/FS.scala
index 5e43694c81c..049b768cf5a 100644
--- a/hail/src/main/scala/is/hail/io/fs/FS.scala
+++ b/hail/src/main/scala/is/hail/io/fs/FS.scala
@@ -319,6 +319,10 @@ trait FS extends Serializable {
   final def openNoCompression(filename: String): SeekableDataInputStream = openNoCompression(filename, false)
 
   def openNoCompression(filename: String, _debug: Boolean): SeekableDataInputStream
+  def readNoCompression(filename: String): Array[Byte] = retryTransientErrors {
+    IOUtils.toByteArray(openNoCompression(filename))
+  }
+
   def createNoCompression(filename: String): PositionedDataOutputStream
 
   def mkDir(dirname: String): Unit = ()
diff --git a/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala
index b9982c4e524..c24685a6596 100644
--- a/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala
+++ b/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala
@@ -5,6 +5,7 @@ import java.io.{ByteArrayInputStream, FileNotFoundException, IOException}
 import java.net.URI
 import java.nio.ByteBuffer
 import java.nio.file.FileSystems
+import java.util.concurrent._
 import org.apache.log4j.Logger
 import com.google.auth.oauth2.ServiceAccountCredentials
 import com.google.cloud.{ReadChannel, WriteChannel}
@@ -18,6 +19,7 @@ import is.hail.utils.fatal
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.{concurrent => scalaConcurrent}
 import scala.reflect.ClassTag
 
 object GoogleStorageFS {
@@ -252,6 +254,11 @@ class GoogleStorageFS(
     new WrappedSeekableDataInputStream(is)
   }
 
+  override def readNoCompression(filename: String): Array[Byte] = retryTransientErrors {
+    val (bucket, path) = getBucketPath(filename)
+    storage.readAllBytes(bucket, path)
+  }
+
   def createNoCompression(filename: String): PositionedDataOutputStream = retryTransientErrors {
     log.info(f"createNoCompression: ${filename}")
     val (bucket, path) = getBucketPath(filename)
diff --git a/hail/src/main/scala/is/hail/io/fs/RouterFS.scala b/hail/src/main/scala/is/hail/io/fs/RouterFS.scala
index 58e86131cf9..496fa4bd9dc 100644
--- a/hail/src/main/scala/is/hail/io/fs/RouterFS.scala
+++ b/hail/src/main/scala/is/hail/io/fs/RouterFS.scala
@@ -15,6 +15,8 @@ class RouterFS(schemes: Map[String, FS], default: String) extends FS {
 
   def createNoCompression(filename: String): PositionedDataOutputStream = lookupFS(filename).createNoCompression(filename)
 
+  override def readNoCompression(filename: String): Array[Byte] = lookupFS(filename).readNoCompression(filename)
+
   override def mkDir(dirname: String): Unit = lookupFS(dirname).mkDir(dirname)
 
   def delete(filename: String, recursive: Boolean) = lookupFS(filename).delete(filename, recursive)
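
The core of the `ServiceBackend` change is swapping `scala.concurrent.Future` and parallel collections (`.par`) for a plain `java.util.concurrent` thread pool: one `Callable` per partition submitted via `invokeAll`, a fair `Semaphore` capping in-flight GCS reads at `MAX_AVAILABLE_GCS_CONNECTIONS` (raised from 100 to 1000), and worker failures unwrapped from `ExecutionException`. A minimal standalone sketch of that pattern; `readAll` and its `read` parameter are illustrative names, not part of this PR:

```scala
import java.util.concurrent.{Callable, ExecutionException, Executors, Semaphore}

import scala.collection.JavaConverters._

object BoundedParallelReads {
  // One Callable per partition on a cached thread pool; a fair Semaphore
  // gates the I/O so at most maxConnections reads are in flight at once.
  def readAll(n: Int, read: Int => Array[Byte], maxConnections: Int = 1000): Array[Array[Byte]] = {
    val executor = Executors.newCachedThreadPool()
    val available = new Semaphore(maxConnections, true)
    try {
      executor.invokeAll(
        IndexedSeq.range(0, n).map { i =>
          new Callable[Array[Byte]]() {
            def call(): Array[Byte] = {
              available.acquire()
              try read(i)
              finally available.release()
            }
          }
        }.asJava
      ).asScala.map(_.get).toArray
    } catch {
      // get() wraps a task's failure in ExecutionException;
      // rethrow the original worker exception, as the diff does.
      case exc: ExecutionException if exc.getCause != null => throw exc.getCause
    } finally executor.shutdown()
  }
}
```

Compared with `Array.range(0, n).par`, this makes the concurrency bound and the failure semantics explicit, and drops the dependency on Scala parallel collections.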
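
The result-file handling also changes shape: the old reader parsed a boolean success flag off a stream (`resultOrHailException`), while the new one pulls the whole object into memory and branches on its first byte. A sketch of the framing as the new reader understands it; `readString` below is an assumed stand-in for Hail's helper, modeled as a 4-byte length prefix followed by that many UTF-8 bytes:

```scala
import java.io.{ByteArrayInputStream, DataInputStream}
import java.nio.charset.StandardCharsets

object ResultFraming {
  // Assumed stand-in for Hail's readString: length-prefixed UTF-8.
  private def readString(is: DataInputStream): String = {
    val buf = new Array[Byte](is.readInt())
    is.readFully(buf)
    new String(buf, StandardCharsets.UTF_8)
  }

  // Byte 0 is the status flag: nonzero means the remainder is the result
  // payload, zero means it is (shortMessage, expandedMessage, errorId).
  def decode(bytes: Array[Byte]): Either[(String, String, Int), Array[Byte]] =
    if (bytes(0) != 0)
      Right(bytes.slice(1, bytes.length))
    else {
      val is = new DataInputStream(new ByteArrayInputStream(bytes.slice(1, bytes.length)))
      Left((readString(is), readString(is), is.readInt()))
    }
}
```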
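
Reading whole objects is what motivates the new `FS.readNoCompression`: the trait's default drains `openNoCompression` through `IOUtils.toByteArray`, while `GoogleStorageFS` overrides it with a single `Storage.readAllBytes` call, skipping the seekable-channel machinery entirely; `RouterFS` just delegates by scheme. A standalone sketch of that fast path using the google-cloud-storage client directly; the bucket and object names are placeholders:

```scala
import com.google.cloud.storage.StorageOptions

object ReadAllBytesSketch {
  def main(args: Array[String]): Unit = {
    // One call buffers the whole object instead of streaming it
    // through a ReadChannel.
    val storage = StorageOptions.getDefaultInstance.getService
    val bytes = storage.readAllBytes("my-bucket", "path/to/result.0")
    println(s"read ${bytes.length} bytes")
  }
}
```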