Skip to content

Commit

Permalink
Implement client-side run
Browse files Browse the repository at this point in the history
**Problem**
`run` task blocks the server, but during the run the server is just
waiting for the built program to finish.

**Solution**
This implements client-side run where the server creates a sandbox
environment, and sends the information to the client,
and the client forks a new JVM to perform the run.
  • Loading branch information
eed3si9n committed Mar 3, 2025
1 parent 003e549 commit 6726563
Show file tree
Hide file tree
Showing 41 changed files with 925 additions and 139 deletions.
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ lazy val protocolProj = (project in file("protocol"))
// General command support and core commands not specific to a build system
lazy val commandProj = (project in file("main-command"))
.enablePlugins(ContrabandPlugin, JsonCodecPlugin)
.dependsOn(protocolProj, completeProj, utilLogging)
.dependsOn(protocolProj, completeProj, utilLogging, runProj)
.settings(
testedBaseSettings,
name := "Command",
Expand Down Expand Up @@ -1072,6 +1072,7 @@ lazy val mainProj = (project in file("main"))
exclude[IncompatibleTemplateDefProblem]("sbt.internal.server.BuildServerReporter"),
exclude[MissingClassProblem]("sbt.internal.CustomHttp*"),
exclude[ReversedMissingMethodProblem]("sbt.JobHandle.isAutoCancel"),
exclude[ReversedMissingMethodProblem]("sbt.BackgroundJobService.createWorkingDirectory"),
)
)
.configure(
Expand Down
27 changes: 15 additions & 12 deletions main-command/src/main/scala/sbt/internal/CommandChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ abstract class CommandChannel {
}
}
}
protected def appendExec(commandLine: String, execId: Option[String]): Boolean =
append(Exec(commandLine, execId.orElse(Some(Exec.newExecId)), Some(CommandSource(name))))
def poll: Option[Exec] = Option(commandQueue.poll)

def prompt(e: ConsolePromptEvent): Unit = userThread.onConsolePromptEvent(e)
Expand All @@ -81,20 +83,21 @@ abstract class CommandChannel {
private[sbt] final def logLevel: Level.Value = level.get
private[this] def setLevel(value: Level.Value, cmd: String): Boolean = {
level.set(value)
append(Exec(cmd, Some(Exec.newExecId), Some(CommandSource(name))))
appendExec(cmd, None)
}
private[sbt] def onCommand: String => Boolean = {
case "error" => setLevel(Level.Error, "error")
case "debug" => setLevel(Level.Debug, "debug")
case "info" => setLevel(Level.Info, "info")
case "warn" => setLevel(Level.Warn, "warn")
case cmd =>
if (cmd.nonEmpty) append(Exec(cmd, Some(Exec.newExecId), Some(CommandSource(name))))
else false
}
private[sbt] def onFastTrackTask: String => Boolean = { s: String =>
private[sbt] def onCommandLine(cmd: String): Boolean =
cmd match {
case "error" => setLevel(Level.Error, "error")
case "debug" => setLevel(Level.Debug, "debug")
case "info" => setLevel(Level.Info, "info")
case "warn" => setLevel(Level.Warn, "warn")
case cmd =>
if (cmd.nonEmpty) appendExec(cmd, None)
else false
}
private[sbt] def onFastTrackTask(cmd: String): Boolean = {
fastTrack.synchronized(fastTrack.forEach { q =>
q.add(new FastTrackTask(this, s))
q.add(new FastTrackTask(this, cmd))
()
})
true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,16 @@ import java.text.DateFormat
import sbt.BasicCommandStrings.{ DashDashDetachStdio, DashDashServer, Shutdown, TerminateAction }
import sbt.internal.client.NetworkClient.Arguments
import sbt.internal.langserver.{ LogMessageParams, MessageType, PublishDiagnosticsParams }
import sbt.internal.worker.{ ClientJobParams, JvmRunInfo, NativeRunInfo, RunInfo }
import sbt.internal.protocol._
import sbt.internal.util.{ ConsoleAppender, ConsoleOut, Signals, Terminal, Util }
import sbt.internal.util.{
ConsoleAppender,
ConsoleOut,
MessageOnlyException,
Signals,
Terminal,
Util
}
import sbt.io.IO
import sbt.io.syntax._
import sbt.protocol._
Expand All @@ -43,6 +51,7 @@ import Serialization.{
attach,
cancelReadSystemIn,
cancelRequest,
clientJob,
promptChannel,
readSystemIn,
systemIn,
Expand All @@ -63,6 +72,7 @@ import Serialization.{
}
import NetworkClient.Arguments
import java.util.concurrent.TimeoutException
import sbt.util.Logger

trait ConsoleInterface {
def appendLog(level: Level.Value, message: => String): Unit
Expand Down Expand Up @@ -166,6 +176,11 @@ class NetworkClient(
case null => inputThread.set(new RawInputThread)
case _ =>
}
private lazy val log: Logger = new Logger {
def trace(t: => Throwable): Unit = ()
def success(message: => String): Unit = ()
def log(level: Level.Value, message: => String): Unit = console.appendLog(level, message)
}

private[sbt] def connectOrStartServerAndConnect(
promptCompleteUsers: Boolean,
Expand Down Expand Up @@ -295,7 +310,18 @@ class NetworkClient(
}
// initiate handshake
val execId = UUID.randomUUID.toString
val initCommand = InitCommand(tkn, Option(execId), Some(true))
val skipAnalysis = true
val opts = InitializeOption(
token = tkn,
skipAnalysis = Some(skipAnalysis),
canWork = Some(true),
)
val initCommand = InitCommand(
token = tkn, // duplicated with opts for compatibility
execId = Option(execId),
skipAnalysis = Some(skipAnalysis), // duplicated with opts for compatibility
initializationOptions = Some(opts),
)
conn.sendString(Serialization.serializeCommandAsJsonMessage(initCommand))
connectionHolder.set(conn)
conn
Expand Down Expand Up @@ -641,6 +667,12 @@ class NetworkClient(
case Success(params) => splitDiagnostics(params); Vector()
case Failure(_) => Vector()
}
case (`clientJob`, Some(json)) =>
import sbt.internal.worker.codec.JsonProtocol._
Converter.fromJson[ClientJobParams](json) match {
case Success(params) => clientSideRun(params).get; Vector.empty
case Failure(_) => Vector.empty
}
case (`Shutdown`, Some(_)) => Vector.empty
case (msg, _) if msg.startsWith("build/") => Vector.empty
case _ =>
Expand Down Expand Up @@ -687,6 +719,58 @@ class NetworkClient(
}
}

private def clientSideRun(params: ClientJobParams): Try[Unit] =
params.runInfo match {
case Some(info) => clientSideRun(info)
case _ => Failure(new MessageOnlyException(s"runInfo is not specified in $params"))
}

private def clientSideRun(runInfo: RunInfo): Try[Unit] = {
def jvmRun(info: JvmRunInfo): Try[Unit] = {
val option = ForkOptions(
javaHome = info.javaHome.map(new File(_)),
outputStrategy = None, // TODO: Handle buffered output etc
bootJars = Vector.empty,
workingDirectory = info.workingDirectory.map(new File(_)),
runJVMOptions = info.jvmOptions,
connectInput = info.connectInput,
envVars = info.environmentVariables,
)
// ForkRun handles exit code handling and cancellation
val runner = new ForkRun(option)
runner
.run(
mainClass = info.mainClass,
classpath = info.classpath.map(_.path).map(new File(_)),
options = info.args,
log = log
)
}
def nativeRun(info: NativeRunInfo): Try[Unit] = {
import java.lang.{ ProcessBuilder => JProcessBuilder }
val option = ForkOptions(
javaHome = None,
outputStrategy = None, // TODO: Handle buffered output etc
bootJars = Vector.empty,
workingDirectory = info.workingDirectory.map(new File(_)),
runJVMOptions = Vector.empty,
connectInput = info.connectInput,
envVars = info.environmentVariables,
)
val command = info.cmd :: info.args.toList
val jpb = new JProcessBuilder(command: _*)
val exitCode = try Fork.blockForExitCode(Fork.forkInternal(option, Nil, jpb))
catch {
case _: InterruptedException =>
log.warn("run canceled")
1
}
Run.processExitCode(exitCode, "runner")
}
if (runInfo.jvm) jvmRun(runInfo.jvmRunInfo.getOrElse(sys.error("missing jvmRunInfo")))
else nativeRun(runInfo.nativeRunInfo.getOrElse(sys.error("missing nativeRunInfo")))
}

def onRequest(msg: JsonRpcRequestMessage): Unit = {
import sbt.protocol.codec.JsonProtocol._
(msg.method, msg.params) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ trait ServerCallback {
private[sbt] def authOptions: Set[ServerAuthentication]
private[sbt] def authenticate(token: String): Boolean
private[sbt] def setInitialized(value: Boolean): Unit
private[sbt] def setInitializeOption(opts: InitializeOption): Unit
private[sbt] def onSettingQuery(execId: Option[String], req: Q): Unit
private[sbt] def onCompletionRequest(execId: Option[String], cp: CP): Unit
private[sbt] def onCancellationRequest(execId: Option[String], crp: CRP): Unit
Expand Down
27 changes: 17 additions & 10 deletions main-command/src/main/scala/sbt/internal/ui/UITask.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ private[sbt] trait UITask extends Runnable with AutoCloseable {
private[sbt] val reader: UITask.Reader
private[this] final def handleInput(s: Either[String, String]): Boolean = s match {
case Left(m) => channel.onFastTrackTask(m)
case Right(cmd) => channel.onCommand(cmd)
case Right(cmd) => channel.onCommandLine(cmd)
}
private[this] val isStopped = new AtomicBoolean(false)
override def run(): Unit = {
Expand Down Expand Up @@ -56,6 +56,20 @@ private[sbt] object UITask {
object Reader {
// Avoid filling the stack trace since it isn't helpful here
object interrupted extends InterruptedException

/**
* Return Left for fast track commands, otherwise return Right(...).
*/
def splitCommand(cmd: String): Either[String, String] =
// We need to put the empty string on the fast track queue so that we can
// reprompt the user if another command is running on the server.
if (cmd.isEmpty()) Left("")
else
cmd match {
case Shutdown | TerminateAction | Cancel => Left(cmd)
case cmd => Right(cmd)
}

def terminalReader(parser: Parser[_])(
terminal: Terminal,
state: State
Expand All @@ -78,15 +92,8 @@ private[sbt] object UITask {
Right("") // should be unreachable
// JLine returns null on ctrl+d when there is no other input. This interprets
// ctrl+d with no imput as an exit
case None => Left(TerminateAction)
case Some(s: String) =>
s.trim() match {
// We need to put the empty string on the fast track queue so that we can
// reprompt the user if another command is running on the server.
case "" => Left("")
case cmd @ (`Shutdown` | `TerminateAction` | `Cancel`) => Left(cmd)
case cmd => Right(cmd)
}
case None => Left(TerminateAction)
case Some(s: String) => splitCommand(s.trim())
}
}
terminal.setPrompt(Prompt.Pending)
Expand Down
2 changes: 2 additions & 0 deletions main/src/main/scala/sbt/BackgroundJobService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ abstract class BackgroundJobService extends Closeable {

def waitFor(job: JobHandle): Unit

private[sbt] def createWorkingDirectory: File

/** Copies classpath to temporary directories. */
def copyClasspath(products: Classpath, full: Classpath, workingDirectory: File): Classpath

Expand Down
5 changes: 3 additions & 2 deletions main/src/main/scala/sbt/Defaults.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import sbt.internal.server.{
BspCompileTask,
BuildServerProtocol,
BuildServerReporter,
ClientJob,
Definition,
LanguageServerProtocol,
ServerHandler,
Expand Down Expand Up @@ -222,7 +223,7 @@ object Defaults extends BuildCommon {
closeClassLoaders :== SysProp.closeClassLoaders,
allowZombieClassLoaders :== true,
packageTimestamp :== Package.defaultTimestamp,
) ++ BuildServerProtocol.globalSettings
) ++ BuildServerProtocol.globalSettings ++ ClientJob.globalSettings

private[sbt] lazy val globalIvyCore: Seq[Setting[_]] =
Seq(
Expand Down Expand Up @@ -2717,7 +2718,7 @@ object Defaults extends BuildCommon {
lazy val configSettings: Seq[Setting[_]] =
Classpaths.configSettings ++ configTasks ++ configPaths ++ packageConfig ++
Classpaths.compilerPluginConfig ++ deprecationSettings ++
BuildServerProtocol.configSettings
BuildServerProtocol.configSettings ++ ClientJob.configSettings

lazy val compileSettings: Seq[Setting[_]] =
configSettings ++ (mainBgRunMainTask +: mainBgRunTask) ++ Classpaths.addUnmanagedLibrary
Expand Down
3 changes: 3 additions & 0 deletions main/src/main/scala/sbt/Keys.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import sbt.internal.remotecache.RemoteCacheArtifact
import sbt.internal.server.BuildServerProtocol.BspFullWorkspace
import sbt.internal.server.{ BuildServerReporter, ServerHandler }
import sbt.internal.util.{ AttributeKey, ProgressState, SourcePosition }
import sbt.internal.worker.ClientJobParams
import sbt.io._
import sbt.librarymanagement.Configurations.CompilerPlugin
import sbt.librarymanagement.LibraryManagementCodec._
Expand Down Expand Up @@ -437,6 +438,8 @@ object Keys {
val bspScalaMainClasses = inputKey[Unit]("Implementation of buildTarget/scalaMainClasses").withRank(DTask)
val bspScalaMainClassesItem = taskKey[ScalaMainClassesItem]("").withRank(DTask)
val bspReporter = taskKey[BuildServerReporter]("").withRank(DTask)
val clientJob = inputKey[ClientJobParams]("Translates a task into a job specification").withRank(Invisible)
val clientJobRunInfo = inputKey[ClientJobParams]("Translates the run task into a job specification").withRank(Invisible)

val useCoursier = settingKey[Boolean]("Use Coursier for dependency resolution.").withRank(BSetting)
val csrCacheDirectory = settingKey[File]("Coursier cache directory. Uses -Dsbt.coursier.home or Coursier's default.").withRank(CSetting)
Expand Down
13 changes: 11 additions & 2 deletions main/src/main/scala/sbt/internal/DefaultBackgroundJobService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,16 @@ private[sbt] abstract class AbstractBackgroundJobService extends BackgroundJobSe
override val isAutoCancel = false
}

private[sbt] def createWorkingDirectory: File = {
val id = nextId.getAndIncrement()
createWorkingDirectory(id)
}
private[sbt] def createWorkingDirectory(id: Long): File = {
val workingDir = serviceTempDir / s"job-$id"
IO.createDirectory(workingDir)
workingDir
}

def doRunInBackground(
spawningTask: ScopedKey[_],
state: State,
Expand All @@ -153,8 +163,7 @@ private[sbt] abstract class AbstractBackgroundJobService extends BackgroundJobSe
val extracted = Project.extract(state)
val logger =
LogManager.constructBackgroundLog(extracted.structure.data, state, context)(spawningTask)
val workingDir = serviceTempDir / s"job-$id"
IO.createDirectory(workingDir)
val workingDir = createWorkingDirectory(id)
val job = try {
new ThreadJobHandle(id, spawningTask, logger, workingDir, start(logger, workingDir))
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ object BuildServerProtocol {
state.respondEvent(result)
}
}.evaluated,
bspScalaMainClasses / aggregate := false
bspScalaMainClasses / aggregate := false,
)

// This will be scoped to Compile, Test, IntegrationTest etc
Expand Down Expand Up @@ -345,7 +345,7 @@ object BuildServerProtocol {
} else {
new BuildServerForwarder(meta, logger, underlying)
}
}
},
)
private[sbt] object Method {
final val Initialize = "build/initialize"
Expand Down
Loading

0 comments on commit 6726563

Please sign in to comment.