diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 9e109ec16ed..d352612732c 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -429,8 +429,8 @@ object ServiceBackendSocketAPI2 { val kind = argv(3) assert(kind == Main.DRIVER) val name = argv(4) - val input = argv(5) - val output = argv(6) + val inputURL = argv(5) + val outputURL = argv(6) val fs = FS.cloudSpecificCacheableFS(s"$scratchDir/secrets/gsa-key/key.json", None) val deployConfig = DeployConfig.fromConfigFile( @@ -455,39 +455,16 @@ object ServiceBackendSocketAPI2 { HailContext(backend, 50, 3) } - retryTransientErrors { - using(fs.openNoCompression(input)) { in => - retryTransientErrors { - using(fs.createNoCompression(output)) { out => - new ServiceBackendSocketAPI2(backend, in, out, sessionId).executeOneCommand() - out.flush() - } - } - } - } + new ServiceBackendSocketAPI2(backend, fs, inputURL, outputURL, sessionId).executeOneCommand() } } -class ServiceBackendSocketAPI2( - private[this] val backend: ServiceBackend, - private[this] val in: InputStream, - private[this] val out: OutputStream, - private[this] val sessionId: String, -) extends Thread { - private[this] val LOAD_REFERENCES_FROM_DATASET = 1 - private[this] val VALUE_TYPE = 2 - private[this] val TABLE_TYPE = 3 - private[this] val MATRIX_TABLE_TYPE = 4 - private[this] val BLOCK_MATRIX_TYPE = 5 - private[this] val EXECUTE = 6 - private[this] val PARSE_VCF_METADATA = 7 - private[this] val IMPORT_FAM = 8 - private[this] val FROM_FASTA_FILE = 9 - +private class HailSocketAPIInputStream( + private[this] val in: InputStream +) extends AutoCloseable { + private[this] var closed: Boolean = false private[this] val dummy = new Array[Byte](8) - private[this] val log = Logger.getLogger(getClass.getName()) - def read(bytes: Array[Byte], off: Int, n: Int): Unit = { assert(off + n <= bytes.length) var read = 0 @@ -536,6 +513,20 @@ class ServiceBackendSocketAPI2( arr } + def close(): Unit = { + if (!closed) { + in.close() + closed = true + } + } +} + +private class HailSocketAPIOutputStream( + private[this] val out: OutputStream +) extends AutoCloseable { + private[this] var closed: Boolean = false + private[this] val dummy = new Array[Byte](8) + def writeBool(b: Boolean): Unit = { out.write(if (b) 1 else 0) } @@ -557,175 +548,211 @@ class ServiceBackendSocketAPI2( def writeString(s: String): Unit = writeBytes(s.getBytes(StandardCharsets.UTF_8)) - def executeOneCommand(): Unit = { - var nFlagsRemaining = readInt() - val flagsMap = mutable.Map[String, String]() - while (nFlagsRemaining > 0) { - val flagName = readString() - val flagValue = readString() - flagsMap.update(flagName, flagValue) - nFlagsRemaining -= 1 - } - val nCustomReferences = readInt() - var i = 0 - while (i < nCustomReferences) { - backend.addReference(ReferenceGenome.fromJSON(readString())) - i += 1 + def close(): Unit = { + if (!closed) { + out.close() + closed = true } - val nLiftoverSourceGenomes = readInt() - val liftovers = mutable.Map[String, mutable.Map[String, String]]() - i = 0 - while (i < nLiftoverSourceGenomes) { - val sourceGenome = readString() - val nLiftovers = readInt() - liftovers(sourceGenome) = mutable.Map[String, String]() - var j = 0 - while (j < nLiftovers) { - val destGenome = readString() - val chainFile = readString() - liftovers(sourceGenome)(destGenome) = chainFile - j += 1 + } +} + +class ServiceBackendSocketAPI2( + private[this] val backend: ServiceBackend, + private[this] val fs: FS, + private[this] val inputURL: String, + private[this] val outputURL: String, + private[this] val sessionId: String, +) extends Thread { + private[this] val LOAD_REFERENCES_FROM_DATASET = 1 + private[this] val VALUE_TYPE = 2 + private[this] val TABLE_TYPE = 3 + private[this] val MATRIX_TABLE_TYPE = 4 + private[this] val BLOCK_MATRIX_TYPE = 5 + private[this] val EXECUTE = 6 + private[this] val PARSE_VCF_METADATA = 7 + private[this] val IMPORT_FAM = 8 + private[this] val FROM_FASTA_FILE = 9 + + private[this] val log = Logger.getLogger(getClass.getName()) + + private[this] def parseInputToCommandThunk(): () => Array[Byte] = retryTransientErrors { + using(fs.openNoCompression(inputURL)) { inputStream => + val input = new HailSocketAPIInputStream(inputStream) + + var nFlagsRemaining = input.readInt() + val flagsMap = mutable.Map[String, String]() + while (nFlagsRemaining > 0) { + val flagName = input.readString() + val flagValue = input.readString() + flagsMap.update(flagName, flagValue) + nFlagsRemaining -= 1 } - i += 1 - } - val nAddedSequences = readInt() - val addedSequences = mutable.Map[String, (String, String)]() - i = 0 - while (i < nAddedSequences) { - val rgName = readString() - val fastaFile = readString() - val indexFile = readString() - addedSequences(rgName) = (fastaFile, indexFile) - i += 1 - } - val workerCores = readString() - val workerMemory = readString() - - var nRegions = readInt() - val regions = { - val regionsArrayBuffer = mutable.ArrayBuffer[String]() - while (nRegions > 0) { - val region = readString() - regionsArrayBuffer += region - nRegions -= 1 + val nCustomReferences = input.readInt() + var i = 0 + while (i < nCustomReferences) { + backend.addReference(ReferenceGenome.fromJSON(input.readString())) + i += 1 + } + val nLiftoverSourceGenomes = input.readInt() + val liftovers = mutable.Map[String, mutable.Map[String, String]]() + i = 0 + while (i < nLiftoverSourceGenomes) { + val sourceGenome = input.readString() + val nLiftovers = input.readInt() + liftovers(sourceGenome) = mutable.Map[String, String]() + var j = 0 + while (j < nLiftovers) { + val destGenome = input.readString() + val chainFile = input.readString() + liftovers(sourceGenome)(destGenome) = chainFile + j += 1 + } + i += 1 + } + val nAddedSequences = input.readInt() + val addedSequences = mutable.Map[String, (String, String)]() + i = 0 + while (i < nAddedSequences) { + val rgName = input.readString() + val fastaFile = input.readString() + val indexFile = input.readString() + addedSequences(rgName) = (fastaFile, indexFile) + i += 1 + } + val workerCores = input.readString() + val workerMemory = input.readString() + + var nRegions = input.readInt() + val regions = { + val regionsArrayBuffer = mutable.ArrayBuffer[String]() + while (nRegions > 0) { + val region = input.readString() + regionsArrayBuffer += region + nRegions -= 1 + } + regionsArrayBuffer.toArray } - regionsArrayBuffer.toArray - } - - val storageRequirement = readString() - val nCloudfuseConfigElements = readInt() - val cloudfuseConfig = new Array[(String, String, Boolean)](nCloudfuseConfigElements) - i = 0 - while (i < nCloudfuseConfigElements) { - val bucket = readString() - val mountPoint = readString() - val readonly = readBool() - cloudfuseConfig(i) = (bucket, mountPoint, readonly) - i += 1 - } + val storageRequirement = input.readString() + val nCloudfuseConfigElements = input.readInt() + val cloudfuseConfig = new Array[(String, String, Boolean)](nCloudfuseConfigElements) + i = 0 + while (i < nCloudfuseConfigElements) { + val bucket = input.readString() + val mountPoint = input.readString() + val readonly = input.readBool() + cloudfuseConfig(i) = (bucket, mountPoint, readonly) + i += 1 + } - val cmd = readInt() - - val tmpdir = readString() - val billingProject = readString() - val remoteTmpDir = readString() - - def withExecuteContext(methodName: String, method: ExecuteContext => Array[Byte]): Array[Byte] = ExecutionTimer.logTime(methodName) { timer => - val flags = HailFeatureFlags.fromMap(flagsMap) - val shouldProfile = flags.get("profile") != null - val fs = FS.cloudSpecificCacheableFS(s"${backend.scratchDir}/secrets/gsa-key/key.json", Some(flags)) - ExecuteContext.scoped( - tmpdir, - "file:///tmp", - backend, - fs, - timer, - null, - backend.theHailClassLoader, - backend.references, - flags - ) { ctx => - liftovers.foreach { case (sourceGenome, liftoversForSource) => - liftoversForSource.foreach { case (destGenome, chainFile) => - ctx.getReference(sourceGenome).addLiftover(ctx, chainFile, destGenome) + val cmd = input.readInt() + + val tmpdir = input.readString() + val billingProject = input.readString() + val remoteTmpDir = input.readString() + + def withExecuteContext( + methodName: String, + method: ExecuteContext => Array[Byte] + ): () => Array[Byte] = { + val flags = HailFeatureFlags.fromMap(flagsMap) + val shouldProfile = flags.get("profile") != null + val fs = FS.cloudSpecificCacheableFS(s"${backend.scratchDir}/secrets/gsa-key/key.json", Some(flags)) + + { () => + ExecutionTimer.logTime(methodName) { timer => + ExecuteContext.scoped( + tmpdir, + "file:///tmp", + backend, + fs, + timer, + null, + backend.theHailClassLoader, + backend.references, + flags + ) { ctx => + liftovers.foreach { case (sourceGenome, liftoversForSource) => + liftoversForSource.foreach { case (destGenome, chainFile) => + ctx.getReference(sourceGenome).addLiftover(ctx, chainFile, destGenome) + } + } + addedSequences.foreach { case (rg, (fastaFile, indexFile)) => + ctx.getReference(rg).addSequence(ctx, fastaFile, indexFile) + } + ctx.backendContext = new ServiceBackendContext(sessionId, billingProject, remoteTmpDir, workerCores, workerMemory, storageRequirement, regions, cloudfuseConfig, shouldProfile) + method(ctx) + } } } - addedSequences.foreach { case (rg, (fastaFile, indexFile)) => - ctx.getReference(rg).addSequence(ctx, fastaFile, indexFile) - } - ctx.backendContext = new ServiceBackendContext(sessionId, billingProject, remoteTmpDir, workerCores, workerMemory, storageRequirement, regions, cloudfuseConfig, shouldProfile) - method(ctx) } - } - try { - val result = (cmd: @switch) match { + (cmd: @switch) match { case LOAD_REFERENCES_FROM_DATASET => - val path = readString() + val path = input.readString() withExecuteContext( "ServiceBackend.loadReferencesFromDataset", backend.loadReferencesFromDataset(_, path).getBytes(StandardCharsets.UTF_8) ) case VALUE_TYPE => - val s = readString() + val s = input.readString() withExecuteContext( "ServiceBackend.valueType", backend.valueType(_, s).getBytes(StandardCharsets.UTF_8) ) case TABLE_TYPE => - val s = readString() + val s = input.readString() withExecuteContext( "ServiceBackend.tableType", backend.tableType(_, s).getBytes(StandardCharsets.UTF_8) ) case MATRIX_TABLE_TYPE => - val s = readString() + val s = input.readString() withExecuteContext( "ServiceBackend.matrixTableType", backend.matrixTableType(_, s).getBytes(StandardCharsets.UTF_8) ) case BLOCK_MATRIX_TYPE => - val s = readString() + val s = input.readString() withExecuteContext( "ServiceBackend.blockMatrixType", backend.blockMatrixType(_, s).getBytes(StandardCharsets.UTF_8) ) case EXECUTE => - val code = readString() - val token = readString() + val code = input.readString() + val token = input.readString() withExecuteContext( "ServiceBackend.execute", { ctx => - withIRFunctionsReadFromInput(ctx) { () => - val bufferSpecString = readString() + withIRFunctionsReadFromInput(input, ctx) { () => + val bufferSpecString = input.readString() backend.execute(ctx, code, token, bufferSpecString) } } ) case PARSE_VCF_METADATA => - val path = readString() + val path = input.readString() withExecuteContext( "ServiceBackend.parseVCFMetadata", backend.parseVCFMetadata(_, path).getBytes(StandardCharsets.UTF_8) ) case IMPORT_FAM => - val path = readString() - val quantPheno = readBool() - val delimiter = readString() - val missing = readString() + val path = input.readString() + val quantPheno = input.readBool() + val delimiter = input.readString() + val missing = input.readString() withExecuteContext( "ServiceBackend.importFam", backend.importFam(_, path, quantPheno, delimiter, missing).getBytes(StandardCharsets.UTF_8) ) case FROM_FASTA_FILE => - val name = readString() - val fastaFile = readString() - val indexFile = readString() - val xContigs = readStringArray() - val yContigs = readStringArray() - val mtContigs = readStringArray() - val parInput = readStringArray() + val name = input.readString() + val fastaFile = input.readString() + val indexFile = input.readString() + val xContigs = input.readStringArray() + val yContigs = input.readStringArray() + val mtContigs = input.readStringArray() + val parInput = input.readStringArray() withExecuteContext( "ServiceBackend.fromFASTAFile", backend.fromFASTAFile( @@ -740,56 +767,47 @@ class ServiceBackendSocketAPI2( ).getBytes(StandardCharsets.UTF_8) ) } - writeBool(true) - writeBytes(result) - } catch { - case exc: HailWorkerException => - writeBool(false) - writeString(exc.shortMessage) - writeString(exc.expandedMessage) - writeInt(exc.errorId) - case t: Throwable => - val (shortMessage, expandedMessage, errorId) = handleForPython(t) - writeBool(false) - writeString(shortMessage) - writeString(expandedMessage) - writeInt(errorId) } } - def withIRFunctionsReadFromInput(ctx: ExecuteContext)(body: () => Array[Byte]): Array[Byte] = { + private[this] def withIRFunctionsReadFromInput( + input: HailSocketAPIInputStream, + ctx: ExecuteContext + )( + body: () => Array[Byte] + ): Array[Byte] = { try { - var nFunctionsRemaining = readInt() + var nFunctionsRemaining = input.readInt() while (nFunctionsRemaining > 0) { - val name = readString() + val name = input.readString() - val nTypeParametersRemaining = readInt() + val nTypeParametersRemaining = input.readInt() val typeParameters = new Array[String](nTypeParametersRemaining) var i = 0 while (i < nTypeParametersRemaining) { - typeParameters(i) = readString() + typeParameters(i) = input.readString() i += 1 } - val nValueParameterNamesRemaining = readInt() + val nValueParameterNamesRemaining = input.readInt() val valueParameterNames = new Array[String](nValueParameterNamesRemaining) i = 0 while (i < nValueParameterNamesRemaining) { - valueParameterNames(i) = readString() + valueParameterNames(i) = input.readString() i += 1 } - val nValueParameterTypesRemaining = readInt() + val nValueParameterTypesRemaining = input.readInt() val valueParameterTypes = new Array[String](nValueParameterTypesRemaining) i = 0 while (i < nValueParameterTypesRemaining) { - valueParameterTypes(i) = readString() + valueParameterTypes(i) = input.readString() i += 1 } - val returnType = readString() + val returnType = input.readString() - val renderedBody = readString() + val renderedBody = input.readString() IRFunctionRegistry.pyRegisterIRForServiceBackend( ctx, @@ -807,4 +825,41 @@ class ServiceBackendSocketAPI2( IRFunctionRegistry.clearUserFunctions() } } + + def executeOneCommand(): Unit = { + val commandThunk = parseInputToCommandThunk() + + try { + val result = commandThunk() + retryTransientErrors { + using(fs.createNoCompression(outputURL)) { outputStream => + val output = new HailSocketAPIOutputStream(outputStream) + output.writeBool(true) + output.writeBytes(result) + } + } + } catch { + case exc: HailWorkerException => + retryTransientErrors { + using(fs.createNoCompression(outputURL)) { outputStream => + val output = new HailSocketAPIOutputStream(outputStream) + output.writeBool(false) + output.writeString(exc.shortMessage) + output.writeString(exc.expandedMessage) + output.writeInt(exc.errorId) + } + } + case t: Throwable => + val (shortMessage, expandedMessage, errorId) = handleForPython(t) + retryTransientErrors { + using(fs.createNoCompression(outputURL)) { outputStream => + val output = new HailSocketAPIOutputStream(outputStream) + output.writeBool(false) + output.writeString(shortMessage) + output.writeString(expandedMessage) + output.writeInt(errorId) + } + } + } + } }