Skip to content

Commit

Permalink
Merge pull request #12670 from dotty-staging/fix-12661
Browse files Browse the repository at this point in the history
Always generate a partial function from a lambda
  • Loading branch information
liufengyun authored Jun 3, 2021
2 parents 92c75ab + 3fad3d3 commit aa5ed33
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 81 deletions.
135 changes: 66 additions & 69 deletions compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package dotty.tools.dotc
package dotty.tools
package dotc
package transform

import core._
Expand All @@ -7,6 +8,7 @@ import MegaPhase._
import SymUtils._
import NullOpsDecorator._
import ast.Trees._
import ast.untpd
import reporting._
import dotty.tools.dotc.util.Spans.Span

Expand Down Expand Up @@ -103,78 +105,73 @@ class ExpandSAMs extends MiniPhase:
* ```
*/
private def toPartialFunction(tree: Block, tpe: Type)(using Context): Tree = {
/** An extractor for match, either contained in a block or standalone. */
object PartialFunctionRHS {
def unapply(tree: Tree): Option[Match] = tree match {
case Block(Nil, expr) => unapply(expr)
case m: Match => Some(m)
case _ => None
}
}

val closureDef(anon @ DefDef(_, List(List(param)), _, _)) = tree
anon.rhs match {
case PartialFunctionRHS(pf) =>
val anonSym = anon.symbol
val anonTpe = anon.tpe.widen
val parents = List(
defn.AbstractPartialFunctionClass.typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType),
defn.SerializableType)
val pfSym = newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS, Synthetic | Final, parents, coord = tree.span)

def overrideSym(sym: Symbol) = sym.copy(
owner = pfSym,
flags = Synthetic | Method | Final | Override,
info = tpe.memberInfo(sym),
coord = tree.span).asTerm.entered
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)

def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(using Context) = {
val selector = tree.selector
val selectorTpe = selector.tpe.widen
val defaultSym = newSymbol(pfParam.owner, nme.WILDCARD, Synthetic | Case, selectorTpe)
val defaultCase =
CaseDef(
Bind(defaultSym, Underscore(selectorTpe)),
EmptyTree,
defaultValue)
val unchecked = selector.annotated(New(ref(defn.UncheckedAnnot.typeRef)))
cpy.Match(tree)(unchecked, cases :+ defaultCase)
.subst(param.symbol :: Nil, pfParam :: Nil)
// Needed because a partial function can be written as:
// param => param match { case "foo" if foo(param) => param }
// And we need to update all references to 'param'
}

def isDefinedAtRhs(paramRefss: List[List[Tree]])(using Context) = {
val tru = Literal(Constant(true))
def translateCase(cdef: CaseDef) =
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
val paramRef = paramRefss.head.head
val defaultValue = Literal(Constant(false))
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
}

def applyOrElseRhs(paramRefss: List[List[Tree]])(using Context) = {
val List(paramRef, defaultRef) = paramRefss(1)
def translateCase(cdef: CaseDef) =
cdef.changeOwner(anonSym, applyOrElseFn)
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
}

val constr = newConstructor(pfSym, Synthetic, Nil, Nil).entered
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn))))
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn))))
val pfDef = ClassDef(pfSym, DefDef(constr), List(isDefinedAtDef, applyOrElseDef))
cpy.Block(tree)(pfDef :: Nil, New(pfSym.typeRef, Nil))

// The right hand side from which to construct the partial function. This is always a Match.
// If the original rhs is already a Match (possibly in braces), return that.
// Otherwise construct a match `x match case _ => rhs` where `x` is the parameter of the closure.
def partialFunRHS(tree: Tree): Match = tree match
case m: Match => m
case Block(Nil, expr) => partialFunRHS(expr)
case _ =>
val found = tpe.baseType(defn.Function1)
report.error(TypeMismatch(found, tpe), tree.srcPos)
tree
Match(ref(param.symbol),
CaseDef(untpd.Ident(nme.WILDCARD).withType(param.symbol.info), EmptyTree, tree) :: Nil)

val pfRHS = partialFunRHS(anon.rhs)
val anonSym = anon.symbol
val anonTpe = anon.tpe.widen
val parents = List(
defn.AbstractPartialFunctionClass.typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType),
defn.SerializableType)
val pfSym = newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS, Synthetic | Final, parents, coord = tree.span)

def overrideSym(sym: Symbol) = sym.copy(
owner = pfSym,
flags = Synthetic | Method | Final | Override,
info = tpe.memberInfo(sym),
coord = tree.span).asTerm.entered
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)

def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(using Context) = {
val selector = tree.selector
val selectorTpe = selector.tpe.widen
val defaultSym = newSymbol(pfParam.owner, nme.WILDCARD, Synthetic | Case, selectorTpe)
val defaultCase =
CaseDef(
Bind(defaultSym, Underscore(selectorTpe)),
EmptyTree,
defaultValue)
val unchecked = selector.annotated(New(ref(defn.UncheckedAnnot.typeRef)))
cpy.Match(tree)(unchecked, cases :+ defaultCase)
.subst(param.symbol :: Nil, pfParam :: Nil)
// Needed because a partial function can be written as:
// param => param match { case "foo" if foo(param) => param }
// And we need to update all references to 'param'
}

def isDefinedAtRhs(paramRefss: List[List[Tree]])(using Context) = {
val tru = Literal(Constant(true))
def translateCase(cdef: CaseDef) =
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
val paramRef = paramRefss.head.head
val defaultValue = Literal(Constant(false))
translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
}

def applyOrElseRhs(paramRefss: List[List[Tree]])(using Context) = {
val List(paramRef, defaultRef) = paramRefss(1)
def translateCase(cdef: CaseDef) =
cdef.changeOwner(anonSym, applyOrElseFn)
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
}

val constr = newConstructor(pfSym, Synthetic, Nil, Nil).entered
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn))))
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn))))
val pfDef = ClassDef(pfSym, DefDef(constr), List(isDefinedAtDef, applyOrElseDef))
cpy.Block(tree)(pfDef :: Nil, New(pfSym.typeRef, Nil))
}

private def checkRefinements(tpe: Type, tree: Tree)(using Context): Type = tpe.dealias match {
Expand Down
12 changes: 0 additions & 12 deletions tests/neg/i4241.scala

This file was deleted.

24 changes: 24 additions & 0 deletions tests/run/i4241.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
object Test extends App {
val a: PartialFunction[Int, Int] = { case x => x }
val b: PartialFunction[Int, Int] = x => x match { case 1 => 1; case 2 => 2 }
val c: PartialFunction[Int, Int] = x => { x match { case 1 => 1 } }
val d: PartialFunction[Int, Int] = x => { { x match { case 1 => 1 } } }

val e: PartialFunction[Int, Int] = x => { println("foo"); x match { case 1 => 1 } }
val f: PartialFunction[Int, Int] = x => x
val g: PartialFunction[Int, String] = { x => x.toString }
val h: PartialFunction[Int, String] = _.toString
assert(a.isDefinedAt(2))
assert(b.isDefinedAt(2))
assert(!b.isDefinedAt(3))
assert(c.isDefinedAt(1))
assert(!c.isDefinedAt(2))
assert(d.isDefinedAt(1))
assert(!d.isDefinedAt(2))
assert(e.isDefinedAt(2))
assert(f.isDefinedAt(2))
assert(g.isDefinedAt(2))
assert(h.isDefinedAt(2))
}


0 comments on commit aa5ed33

Please sign in to comment.