Skip to content

Commit

Permalink
Support cross- and multi-source projects with git
Browse files Browse the repository at this point in the history
  • Loading branch information
Albert Meltzer committed Dec 24, 2021
1 parent b6f1a9f commit 43ed34a
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 34 deletions.
80 changes: 47 additions & 33 deletions plugin/src/main/scala/org/scalafmt/sbt/ScalafmtPlugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,17 @@ object ScalafmtPlugin extends AutoPlugin {
@inline private def asRelative(file: File): String =
baseDir.relativize(file.getCanonicalFile.toPath).toString

private def filterFiles(sources: Seq[File]): Seq[File] = {
val filter = getFileFilter()
private def filterFiles(sources: Seq[File], dirs: Seq[File]): Seq[File] = {
val filter = getFileFilter(dirs)
sources.map(_.getCanonicalFile).distinct.filter { file =>
val path = file.toPath
scalafmtSession.matchesProjectFilters(path) && filter(path)
}
}

private def getFileFilter(): Path => Boolean = {
private def getFileFilter(dirs: Seq[File]): Path => Boolean = {
// dirs don't have to be within baseDir but within the same git tree
def absDirs = dirs.map(x => AbsoluteFile(x.getCanonicalFile.toPath))
def gitOps = GitOps.FactoryImpl(AbsoluteFile(baseDir))
def getFromFiles(getFiles: => Seq[AbsoluteFile], gitCmd: => String) = {
def gitMessage = s"[git $gitCmd] ($baseDir)"
Expand All @@ -199,12 +201,12 @@ object ScalafmtPlugin extends AutoPlugin {
}

if (filterMode == FilterMode.diffDirty)
getFromFiles(gitOps.status(), "status")
getFromFiles(gitOps.status(absDirs: _*), "status")
else if (filterMode.startsWith(FilterMode.diffRefPrefix)) {
val branch = filterMode.substring(FilterMode.diffRefPrefix.length)
getFromFiles(gitOps.diff(branch), s"diff $branch")
getFromFiles(gitOps.diff(branch, absDirs: _*), s"diff $branch")
} else if (filterMode != FilterMode.none && scalafmtSession.isGitOnly)
getFromFiles(gitOps.lsTree(), "ls-files")
getFromFiles(gitOps.lsTree(absDirs: _*), "ls-files")
else {
log.debug("considering all files (no git)")
_ => true
Expand Down Expand Up @@ -244,8 +246,8 @@ object ScalafmtPlugin extends AutoPlugin {
res
}

def formatTrackedSources(sources: Seq[File]): Unit = {
val filteredSources = filterFiles(sources)
def formatTrackedSources(sources: Seq[File], dirs: Seq[File]): Unit = {
val filteredSources = filterFiles(sources, dirs)
trackSourcesAndConfig(cacheStoreFactory, filteredSources) {
(outDiff, configChanged, prev) =>
val filesToFormat: Seq[File] =
Expand All @@ -261,8 +263,8 @@ object ScalafmtPlugin extends AutoPlugin {
}
}

def formatSources(sources: Seq[File]): Unit =
formatFilteredSources(filterFiles(sources))
def formatSources(sources: Seq[File], dirs: Seq[File]): Unit =
formatFilteredSources(filterFiles(sources, dirs))

private def formatFilteredSources(sources: Seq[File]): Unit = {
if (sources.nonEmpty)
Expand All @@ -274,8 +276,8 @@ object ScalafmtPlugin extends AutoPlugin {
if (cnt > 0) log.info(s"Reformatted $cnt Scala sources")
}

def checkTrackedSources(sources: Seq[File]): Unit = {
val filteredSources = filterFiles(sources)
def checkTrackedSources(sources: Seq[File], dirs: Seq[File]): Unit = {
val filteredSources = filterFiles(sources, dirs)
val result = trackSourcesAndConfig(cacheStoreFactory, filteredSources) {
(outDiff, configChanged, prev) =>
val filesToCheck: Seq[File] =
Expand All @@ -300,8 +302,8 @@ object ScalafmtPlugin extends AutoPlugin {
throwOnFailure(result)
}

def checkSources(sources: Seq[File]): Unit =
throwOnFailure(checkFilteredSources(filterFiles(sources)))
def checkSources(sources: Seq[File], dirs: Seq[File]): Unit =
throwOnFailure(checkFilteredSources(filterFiles(sources, dirs)))

private def checkFilteredSources(sources: Seq[File]): ScalafmtAnalysis = {
if (sources.nonEmpty) {
Expand Down Expand Up @@ -393,57 +395,69 @@ object ScalafmtPlugin extends AutoPlugin {
}
}

private def scalafmtTask(sources: Seq[File], session: FormatSession) =
private def scalafmtTask(
sources: Seq[File],
dirs: Seq[File],
session: FormatSession
) =
Def.task {
session.formatTrackedSources(sources)
session.formatTrackedSources(sources, dirs)
} tag (ScalafmtTagPack: _*)

private def scalafmtCheckTask(sources: Seq[File], session: FormatSession) =
private def scalafmtCheckTask(
sources: Seq[File],
dirs: Seq[File],
session: FormatSession
) =
Def.task {
session.checkTrackedSources(sources)
session.checkTrackedSources(sources, dirs)
} tag (ScalafmtTagPack: _*)

private def getScalafmtSourcesTask(
f: (Seq[File], FormatSession) => InitTask
f: (Seq[File], Seq[File], FormatSession) => InitTask
) = Def.taskDyn[Unit] {
val sources = (unmanagedSources in scalafmt).?.value.getOrElse(Seq.empty)
getScalafmtTask(f)(sources, scalaConfig.value)
val dirs = (unmanagedSourceDirectories in scalafmt).?.value.getOrElse(Nil)
getScalafmtTask(f)(sources, dirs, scalaConfig.value)
}

private def scalafmtSbtTask(
sources: Seq[File],
dirs: Seq[File],
session: FormatSession
) = Def.task {
session.formatSources(sources)
session.formatSources(sources, dirs)
} tag (ScalafmtTagPack: _*)

private def scalafmtSbtCheckTask(
sources: Seq[File],
dirs: Seq[File],
session: FormatSession
) = Def.task {
session.checkSources(sources)
session.checkSources(sources, dirs)
} tag (ScalafmtTagPack: _*)

private def getScalafmtSbtTasks(
func: (Seq[File], FormatSession) => InitTask
func: (Seq[File], Seq[File], FormatSession) => InitTask
) = Def.taskDyn {
joinScalafmtTasks(func)(
(sbtSources.value, sbtConfig.value),
(metabuildSources.value, scalaConfig.value)
(sbtSources.value, Nil, sbtConfig.value),
(metabuildSources.value, Nil, scalaConfig.value)
)
}

private def joinScalafmtTasks(
func: (Seq[File], FormatSession) => InitTask
)(tuples: (Seq[File], Path)*) = {
val tasks = tuples
.map { case (files, config) => getScalafmtTask(func)(files, config) }
func: (Seq[File], Seq[File], FormatSession) => InitTask
)(tuples: (Seq[File], Seq[File], Path)*) = {
val tasks = tuples.map { case (files, dirs, config) =>
getScalafmtTask(func)(files, dirs, config)
}
Def.sequential(tasks.tail.toList, tasks.head)
}

private def getScalafmtTask(
func: (Seq[File], FormatSession) => InitTask
)(files: Seq[File], config: Path) = Def.taskDyn[Unit] {
func: (Seq[File], Seq[File], FormatSession) => InitTask
)(files: Seq[File], dirs: Seq[File], config: Path) = Def.taskDyn[Unit] {
if (files.isEmpty) Def.task(Unit)
else {
val session = new FormatSession(
Expand All @@ -460,7 +474,7 @@ object ScalafmtPlugin extends AutoPlugin {
scalafmtDetailedError.value
)
)
func(files, session)
func(files, dirs, session)
}
}

Expand Down Expand Up @@ -505,7 +519,7 @@ object ScalafmtPlugin extends AutoPlugin {
scalafmtFailOnErrors.value,
scalafmtDetailedError.value
)
).formatSources(absFiles)
).formatSources(absFiles, Nil)
}
)

Expand Down
2 changes: 1 addition & 1 deletion plugin/src/sbt-test/scalafmt-sbt/sbt/test
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ $ exec git -C p19 add "jvm/src/main/scala/TestGood.scala"
> p19/scalafmtCheck
$ copy-file changes/invalid.scala p19/shared/src/main/scala/TestInvalid1.scala
$ exec git -C p19 add "shared/src/main/scala/TestInvalid1.scala"
> p19/scalafmtCheck
-> p19/scalafmtCheck

$ copy-file changes/target/managed.scala project/target/managed.scala
$ copy-file changes/x/Something.scala project/x/Something.scala
Expand Down

0 comments on commit 43ed34a

Please sign in to comment.