Skip to content

Commit

Permalink
ScalafmtConfig: add methods to detect rewrites
Browse files Browse the repository at this point in the history
Previously, it was sufficient to check that `rules` was non-empty; but
it is no longer enough as some rewrites are triggered via other params.
  • Loading branch information
kitbellew committed Nov 22, 2021
1 parent 9853e85 commit ad9da7f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@ case class RewriteSettings(
}

object RewriteSettings {

val default = RewriteSettings()

implicit lazy val surface: generic.Surface[RewriteSettings] =
generic.deriveSurface
implicit lazy val encoder: ConfEncoder[RewriteSettings] =
generic.deriveEncoder

implicit lazy val decoder: ConfDecoderEx[RewriteSettings] =
generic.deriveDecoderEx(RewriteSettings()).noTypos.flatMap {
generic.deriveDecoderEx(default).noTypos.flatMap {
Imports.validateImports
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import scala.meta.Dialect
import scala.util.Try

import metaconfig._
import org.scalafmt.rewrite.FormatTokensRewrite
import org.scalafmt.sysops.AbsoluteFile
import org.scalafmt.sysops.FileOps
import org.scalafmt.sysops.OsSpecific._
Expand Down Expand Up @@ -98,7 +99,7 @@ case class ScalafmtConfig(
literals: Literals = Literals(),
lineEndings: LineEndings = LineEndings.unix,
rewriteTokens: Map[String, String] = Map.empty[String, String],
rewrite: RewriteSettings = RewriteSettings(),
rewrite: RewriteSettings = RewriteSettings.default,
indentOperator: IndentOperator = IndentOperator(),
newlines: Newlines = Newlines(),
runner: ScalafmtRunner = ScalafmtRunner.default,
Expand Down Expand Up @@ -186,6 +187,19 @@ case class ScalafmtConfig(
private[scalafmt] lazy val dialect = runner.getDialect

private[scalafmt] def getTrailingCommas = rewrite.trailingCommas.style

// used in ScalafmtReflectConfig
def hasRewrites: Boolean = {
rewrite.rewriteFactoryRules.nonEmpty ||
FormatTokensRewrite.getEnabledFactories(this).nonEmpty
}

// used in ScalafmtReflectConfig
def withoutRewrites: ScalafmtConfig = copy(
trailingCommas = None,
rewrite = RewriteSettings.default
)

}

object ScalafmtConfig {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,13 @@ import scala.util.Try
class ScalafmtReflectConfig private[dynamic] (val fmtReflect: ScalafmtReflect)(
private[dynamic] val target: Object
) {
import ScalafmtReflectConfig._
import fmtReflect.classLoader
private val targetCls = target.getClass
private val constructor: Constructor[_] = targetCls.getConstructors()(0)
private val constructorParams = constructor.getParameters.map(_.getName)
private val rewriteParamIdx =
constructorParams.indexOf("rewrite").ensuring(_ >= 0)
private val emptyRewrites =
target.invoke("apply$default$" + (rewriteParamIdx + 1))

private val dialectCls = classLoader.loadClass("scala.meta.Dialect")
private val dialectsCls = classLoader.loadClass("scala.meta.dialects.package")

private val rewriteRulesMethod = Try(targetCls.getMethod("rewrite")).toOption

@inline def getVersion = fmtReflect.version

def isIncludedInProject(filename: String): Boolean = {
Expand All @@ -51,33 +44,48 @@ class ScalafmtReflectConfig private[dynamic] (val fmtReflect: ScalafmtReflect)(
)

def withoutRewriteRules: ScalafmtReflectConfig = {
if (hasRewriteRules) {
if (getVersion < ScalafmtVersion(3, 2, 0)) withoutRewriteRulesPre320
else new ScalafmtReflectConfig(fmtReflect)(target.invoke("withoutRewrites"))
}

private def withoutRewriteRulesPre320: ScalafmtReflectConfig =
if (!hasRewriteRulesPre320) this
else {
// emulating this.copy(rewrite = RewriteSettings())
val constructor: Constructor[_] = targetCls.getConstructors()(0)
val constructorParams = constructor.getParameters.map(_.getName)
val rewriteParamIdx =
constructorParams.indexOf(rewriteFieldName).ensuring(_ >= 0)
val emptyRewrites =
target.invoke("apply$default$" + (rewriteParamIdx + 1))
val fieldsValues = constructorParams.map(param => target.invoke(param))
fieldsValues(rewriteParamIdx) = emptyRewrites
new ScalafmtReflectConfig(fmtReflect)(
constructor.newInstance(fieldsValues: _*).asInstanceOf[Object]
)
} else {
this
}
}

def hasRewriteRules: Boolean = {
rewriteRulesMethod match {
case Some(method) =>
// scalafmt >= v0.4.1
val rewriteSettings = method.invoke(target)
!rewriteSettings.invoke("rules").invokeAs[Boolean]("isEmpty")
case None =>
false
}
if (getVersion < ScalafmtVersion(3, 2, 0)) hasRewriteRulesPre320
else target.invokeAs[Boolean]("hasRewrites")
}

private def hasRewriteRulesPre320: Boolean = // scalafmt >= v0.4.1
Try {
val rules = target.invoke(rewriteFieldName).invoke("rules")
!rules.invokeAs[Boolean]("isEmpty")
}.getOrElse(false)

def format(code: String, file: Option[Path]): String =
fmtReflect.format(code, this, file)

override def equals(obj: Any): Boolean = target.equals(obj)

override def hashCode(): Int = target.hashCode()
}

private object ScalafmtReflectConfig {

val rewriteFieldName = "rewrite"

}

0 comments on commit ad9da7f

Please sign in to comment.