Skip to content

Commit

Permalink
ApplyInfix ApplyUnary ApplyUsing
Browse files Browse the repository at this point in the history
  • Loading branch information
zmerr committed Jun 9, 2022
1 parent a073adf commit 12e9c7c
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -141,22 +141,83 @@ class FlatMapToForComprehensionCodeAction(
tree: Term,
newName: Term.Name,
allowedToGetInsideApply: Boolean
): Term =
): (Term, Int) =
tree match {
case Term.Apply(fun, args) if allowedToGetInsideApply =>
Term.Apply(
replacePlaceHolder(fun, newName, false),
args.map(replacePlaceHolder(_, newName, false))
)
val (newFun, funReplacementTimes) =
replacePlaceHolder(fun, newName, false)
val (newArgs, argsReplacementTimes) =
args.map(replacePlaceHolder(_, newName, false)).unzip

val replacementTimes =
argsReplacementTimes.fold(funReplacementTimes)((result, newElem) =>
result + newElem
)

(Term.Apply(newFun, newArgs), replacementTimes)

case Term.ApplyUnary(op, arg) if allowedToGetInsideApply =>
val (newArg, argReplacementTimes) =
replacePlaceHolder(arg, newName, allowedToGetInsideApply)
(Term.ApplyUnary(op, newArg), argReplacementTimes)

case Term.ApplyUsing(fun, args) if allowedToGetInsideApply =>
val (newFun, funReplacementTimes) =
replacePlaceHolder(fun, newName, false) // TODO false??
val (newArgs, argsReplacementTimes) =
args.map(replacePlaceHolder(_, newName, false)).unzip // TODO false??

val replacementTimes =
argsReplacementTimes.fold(funReplacementTimes)((result, newElem) =>
result + newElem
)

(Term.ApplyUsing(newFun, newArgs), replacementTimes)

case Term.ApplyInfix(lhs, op, targs, args) if allowedToGetInsideApply =>
val (newLHS, lhsReplacementTimes) =
replacePlaceHolder(lhs, newName, allowedToGetInsideApply)
val (newArgs, argsReplacementTimes) =
args
.map(replacePlaceHolder(_, newName, allowedToGetInsideApply))
.unzip

val replacementTimes =
argsReplacementTimes.fold(lhsReplacementTimes)((result, newElem) =>
result + newElem
)

(Term.ApplyInfix(newLHS, op, targs, newArgs), replacementTimes)

case Term.Select(qual, name) =>
Term.Select(
replacePlaceHolder(qual, newName, allowedToGetInsideApply),
name
val (newQual, qualReplacementTimes) =
replacePlaceHolder(qual, newName, allowedToGetInsideApply)
(
Term.Select(
newQual,
name
),
qualReplacementTimes
)
case Term.Placeholder() => newName
case other => other
case Term.Placeholder() => (newName, 1)
case other => (other, 0)
}

private def replacePlaceHolderInTermWithNewName(
term: Term,
generatedByMetalsValues: Set[String]
): (Option[(String, Term)], Set[String]) = {
val (newName, generatedValues) =
createNewName(term, generatedByMetalsValues)

val (newTerm, replacementTimes) =
replacePlaceHolder(term, Term.Name(newName), true)
(
if (replacementTimes == 1) Some((newName, newTerm)) else None,
generatedValues
)
}

private def processValueNameAndNextQual(
tree: Tree,
generatedByMetalsValues: Set[String]
Expand All @@ -170,21 +231,36 @@ class FlatMapToForComprehensionCodeAction(
(Some((param.name.value, term)), generatedByMetalsValues)

case Term.AnonymousFunction(term) =>
replacePlaceHolderInTermWithNewName(term, generatedByMetalsValues)
case term: Term.ApplyInfix =>
replacePlaceHolderInTermWithNewName(term, generatedByMetalsValues)
case term: Term.ApplyUnary =>
replacePlaceHolderInTermWithNewName(term, generatedByMetalsValues)
case term: Term.ApplyUsing =>
replacePlaceHolderInTermWithNewName(term, generatedByMetalsValues)
case Term.Block(List(function)) =>
processValueNameAndNextQual(function, generatedByMetalsValues)
case term: Term.Name =>
val (newName, generatedValues) =
createNewName(tree, generatedByMetalsValues)
(
Some((newName, replacePlaceHolder(term, Term.Name(newName), true))),
Some((newName, Term.Apply(term, List(Term.Name(newName))))),
generatedValues
)
case Term.Block(List(function)) =>
processValueNameAndNextQual(function, generatedByMetalsValues)
case term: Term =>

case term: Term.Select =>
val (newName, generatedValues) =
createNewName(tree, generatedByMetalsValues)
(
Some((newName, Term.Apply(term, List(Term.Name(newName))))),
generatedValues
)

case _ =>
(
None,
generatedByMetalsValues
)
}
}

Expand Down Expand Up @@ -232,54 +308,35 @@ class FlatMapToForComprehensionCodeAction(
}

private def processMap(
perhapseValueNameAndNextQual: Option[(String, Term)],
newMetalsNames: Set[String],
generatedByMetalsNames: Set[String],
perhapseLastName: Option[String],
shouldFlat: Boolean,
existingForElements: List[Enumerator],
currentYieldTerm: Option[Term],
termApply: Term.Apply,
elems: List[Enumerator],
maybeYieldTerm: Option[Term],
updatedGenByMetalsVals: Set[String],
valueName: String,
termSelectQual: Term
): (List[Enumerator], Option[Term], Set[String]) = {
val result = for {
valueName <- perhapseValueNameAndNextQual.map(_._1)
nextQual <- perhapseValueNameAndNextQual.map(_._2)
} yield {
val (elems, maybeYieldTerm, updatedGenByMetalsVals) =
obtainNextYieldAndElemsForMap(
newMetalsNames,
perhapseLastName,
shouldFlat,
existingForElements,
currentYieldTerm,
termApply,
nextQual
)

termSelectQual match {
case qualTermApply: Term.Apply =>
extractChainedForYield(
Some(valueName),
maybeYieldTerm,
elems,
qualTermApply,
updatedGenByMetalsVals
)
case otherQual =>
(
Enumerator.Generator(
Pat.Var.apply(Term.Name.apply(valueName)),
otherQual
)
+: elems,
maybeYieldTerm,
updatedGenByMetalsVals
termSelectQual match {
case qualTermApply: Term.Apply =>
extractChainedForYield(
Some(valueName),
maybeYieldTerm,
elems,
qualTermApply,
updatedGenByMetalsVals
)
case otherQual =>
(
Enumerator.Generator(
Pat.Var.apply(Term.Name.apply(valueName)),
otherQual
)
+: elems,
maybeYieldTerm,
updatedGenByMetalsVals
)

}
}
result.getOrElse(List.empty, None, generatedByMetalsNames)

}

private def processFilter(
Expand Down Expand Up @@ -360,17 +417,31 @@ class FlatMapToForComprehensionCodeAction(
case termSelect: Term.Select
if termSelect.name.value == "flatMap" || termSelect.name.value == "map" =>
val shouldFlat = termSelect.name.value == "flatMap"
processMap(
perhapseValueNameAndNextQual,
newMetalsNames,
generatedByMetalsValues,
perhapseLastName,
shouldFlat,
existingForElements,
currentYieldTerm,
termApply,
termSelect.qual
)

val result = for {
valueName <- perhapseValueNameAndNextQual.map(_._1)
nextQual <- perhapseValueNameAndNextQual.map(_._2)
} yield {
val (elems, maybeYieldTerm, updatedGenByMetalsVals) =
obtainNextYieldAndElemsForMap(
newMetalsNames,
perhapseLastName,
shouldFlat,
existingForElements,
currentYieldTerm,
termApply,
nextQual
)

processMap(
elems,
maybeYieldTerm,
updatedGenByMetalsVals,
valueName,
termSelect.qual
)
}
result.getOrElse(List.empty, None, generatedByMetalsValues)

case termSelect: Term.Select
if termSelect.name.value == "filter" || termSelect.name.value == "filterNot" ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class FlatMapToForComprehensionSuite
| def double(x : Int, y: Int = 1) = y * x
| def check(x: Int) = true
| val list = List(1, 2, 3)
| def negate(a: Boolean) = !a
|
| val res3 = list
| .flatMap{
Expand All @@ -25,38 +26,49 @@ class FlatMapToForComprehensionSuite
| .filterNot(_ => true)
| .map(_ => 7)
| .map(c => c - 1)
| .map(5 + double(_, 7).toFloat.toInt / 8 + 6)
| .filter(d => d > 1)
| .map(double(_, 5))
| .map( double(_, 4).toFloat.toDouble)
| .m<<>>ap( _.toInt.compare(3))
| .map( _.toInt.compare(3))
| .map(_ > 2)
| .map(!negate(_))
| .m<<>>ap( true && !negate(_) && false)
|
|}
|""".stripMargin,
s"""|${RewriteBracesParensCodeAction.toBraces}
s"""|${RewriteBracesParensCodeAction.toBraces("map")}
|${FlatMapToForComprehensionCodeAction.flatMapToForComprehension}
|""".stripMargin,
"""|object A {
| def double(x : Int, y: Int = 1) = y * x
| def check(x: Int) = true
| val list = List(1, 2, 3)
| def negate(a: Boolean) = !a
|
| val res3 = {
| for {
| a <- list
| generatedByMetals3 <- {
| generatedByMetals8 <- {
| val m = 6
| Some(a + 1).map(b => b + 3 + 4)
| }
| if check(generatedByMetals3)
| generatedByMetals2 = generatedByMetals3
| if check(generatedByMetals8)
| generatedByMetals7 = generatedByMetals8
| if !true
| generatedByMetals1 = generatedByMetals2
| generatedByMetals6 = generatedByMetals7
| c = 7
| d = c - 1
| generatedByMetals5 = c - 1
| d = 5 + double(generatedByMetals5, 7).toFloat.toInt / 8 + 6
| if d > 1
| generatedByMetals0 = d
| generatedByMetals = double(generatedByMetals0, 4).toFloat.toDouble
| generatedByMetals4 = d
| generatedByMetals3 = double(generatedByMetals4, 5)
| generatedByMetals2 = double(generatedByMetals3, 4).toFloat.toDouble
| generatedByMetals1 = generatedByMetals2.toInt.compare(3)
| generatedByMetals0 = generatedByMetals1 > 2
| generatedByMetals = !negate(generatedByMetals0)
| } yield {
| generatedByMetals.toInt.compare(3)
| true && !negate(generatedByMetals) && false
| }
| }
|
Expand All @@ -83,14 +95,15 @@ class FlatMapToForComprehensionSuite
| .filterNot(_ => true)
| .map(_ => 7)
| .map(c => c - 1)
| .map( double(_, 4).toFloat.toInt)
| .filter(d => d > 1)
| .map( double(_, 4).toFloat.toDouble)
| .map(5 + double(_, 7).toFloat.toInt / 8 + 6)
| .map( _.toInt.compare(3))
| .fl<<>>atMap( m => Some(m * 3))
|
|}
|""".stripMargin,
s"""|${RewriteBracesParensCodeAction.toBraces}
s"""|${RewriteBracesParensCodeAction.toBraces("flatMap")}
|${FlatMapToForComprehensionCodeAction.flatMapToForComprehension}
|""".stripMargin,
"""|object A {
Expand All @@ -101,19 +114,20 @@ class FlatMapToForComprehensionSuite
| val res3 = {
| for {
| a <- list
| generatedByMetals4 <- {
| generatedByMetals5 <- {
| val m = 6
| Some(a + 1).map(b => b + 3 + 4)
| }
| if check(generatedByMetals4)
| generatedByMetals3 = generatedByMetals4
| if check(generatedByMetals5)
| generatedByMetals4 = generatedByMetals5
| if !true
| generatedByMetals2 = generatedByMetals3
| generatedByMetals3 = generatedByMetals4
| c = 7
| d = c - 1
| generatedByMetals2 = c - 1
| d = double(generatedByMetals2, 4).toFloat.toInt
| if d > 1
| generatedByMetals1 = d
| generatedByMetals0 = double(generatedByMetals1, 4).toFloat.toDouble
| generatedByMetals0 = 5 + double(generatedByMetals1, 7).toFloat.toInt / 8 + 6
| m = generatedByMetals0.toInt.compare(3)
| generatedByMetals <- Some(m * 3)
| } yield {
Expand Down

0 comments on commit 12e9c7c

Please sign in to comment.