Skip to content

Commit

Permalink
Fix i11694: extract function type and SAM in union type
Browse files Browse the repository at this point in the history
  • Loading branch information
noti0na1 committed Apr 1, 2021
1 parent a288432 commit 9652826
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 10 deletions.
45 changes: 35 additions & 10 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,20 @@ class Typer extends Namer
newTypeVar(apply(bounds.orElse(TypeBounds.empty)).bounds)
case _ => mapOver(t)
}
def extractInUnion(t: Type): Seq[Type] = t match {
case t: OrType =>
extractInUnion(t.tp1) ++ extractInUnion(t.tp2)
case t: TypeParamRef =>
extractInUnion(ctx.typerState.constraint.entry(t).bounds.hi)
case t if defn.isNonRefinedFunction(t) =>
Seq(t)
case SAMType(_: MethodType) =>
Seq(t)
case _ =>
Nil
}
def defaultResult = (List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())

val pt1 = pt.stripTypeVar.dealias
if (pt1 ne pt1.dropDependentRefinement)
&& defn.isContextFunctionType(pt1.nonPrivateMember(nme.apply).info.finalResultType)
Expand All @@ -1115,22 +1129,25 @@ class Typer extends Namer
i"""Implementation restriction: Expected result type $pt1
|is a curried dependent context function type. Such types are not yet supported.""",
tree.srcPos)
pt1 match {

val elems = extractInUnion(pt1)
if elems.length != 1 then
// The union type containing multiple function types is ignored
defaultResult
else elems.head match {
case pt1 if defn.isNonRefinedFunction(pt1) =>
// if expected parameter type(s) are wildcards, approximate from below.
// if expected result type is a wildcard, approximate from above.
// this can type the greatest set of admissible closures.
(pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last)))
case SAMType(sam @ MethodTpe(_, formals, restpe)) =>
(formals,
if (sam.isResultDependent)
untpd.DependentTypeTree(syms => restpe.substParams(sam, syms.map(_.termRef)))
else
typeTree(restpe))
case tp: TypeParamRef =>
decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity, tree)
if (sam.isResultDependent)
untpd.DependentTypeTree(syms => restpe.substParams(sam, syms.map(_.termRef)))
else
typeTree(restpe))
case _ =>
(List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
defaultResult
}
}

Expand Down Expand Up @@ -1375,14 +1392,22 @@ class Typer extends Namer
}

def typedClosure(tree: untpd.Closure, pt: Type)(using Context): Tree = {
def extractInUnion(t: Type): Seq[Type] = t match {
case t: OrType =>
extractInUnion(t.tp1) ++ extractInUnion(t.tp2)
case SAMType(_) =>
Seq(t)
case _ =>
Nil
}
val env1 = tree.env mapconserve (typed(_))
val meth1 = typedUnadapted(tree.meth)
val target =
if (tree.tpt.isEmpty)
meth1.tpe.widen match {
case mt: MethodType =>
pt.stripNull match {
case pt @ SAMType(sam)
extractInUnion(pt) match {
case Seq(pt @ SAMType(sam))
if !defn.isFunctionType(pt) && mt <:< sam =>
// SAMs of the form C[?] where C is a class cannot be conversion targets.
// The resulting class `class $anon extends C[?] {...}` would be illegal,
Expand Down
4 changes: 4 additions & 0 deletions tests/explicit-nulls/pos/i11694.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = {
val x = new java.util.ArrayList[String]()
val y = x.stream().nn.filter(s => s.nn.length > 0)
}
19 changes: 19 additions & 0 deletions tests/neg/i11694.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
def test1 = {
def f11: (Int => Int) | Unit = x => x + 1
def f12: Null | (Int => Int) = x => x + 1

def f21: (Int => Int) | Null = x => x + 1
def f22: Null | (Int => Int) = x => x + 1
}

def test2 = {
def f1: (Int => String) | (Int => Int) | Null = x => x + 1 // error
def f2: (Int => String) | Function[String, Int] | Null = x => "" + x // error
def f3: Function[Int, Int] | Function[String, Int] | Null = x => x + 1 // error
}

def test3 = {
import java.util.function.Function
val f1: Function[String, Int] | Unit = x => x.length
val f2: Function[String, Int] | Null = x => x.length
}

0 comments on commit 9652826

Please sign in to comment.