Skip to content

Commit

Permalink
Fix i11694: extract function type and SAM in union type
Browse files Browse the repository at this point in the history
  • Loading branch information
noti0na1 committed Jun 3, 2021
1 parent 92c75ab commit 1df84ec
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 14 deletions.
18 changes: 18 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1651,6 +1651,24 @@ object Types {
case _ => resultType
}

/** Find the function type in union.
* If there are multiple function types, NoType is returned.
*/
def findFuntionTypeInUnion(using Context): Type = this match {
case t: OrType =>
val t1 = t.tp1.findFuntionTypeInUnion
if t1 == NoType then t.tp2.findFuntionTypeInUnion else
val t2 = t.tp2.findFuntionTypeInUnion
// Returen NoType if the union contains multiple function types
if t2 == NoType then t1 else NoType
case t if defn.isNonRefinedFunction(t) =>
t
case t @ SAMType(_: MethodType) =>
t
case _ =>
NoType
}

/** This type seen as a TypeBounds */
final def bounds(using Context): TypeBounds = this match {
case tp: TypeBounds => tp
Expand Down
32 changes: 18 additions & 14 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,7 @@ class Typer extends Namer
newTypeVar(apply(bounds.orElse(TypeBounds.empty)).bounds)
case _ => mapOver(t)
}

val pt1 = pt.stripTypeVar.dealias
if (pt1 ne pt1.dropDependentRefinement)
&& defn.isContextFunctionType(pt1.nonPrivateMember(nme.apply).info.finalResultType)
Expand All @@ -1133,22 +1134,25 @@ class Typer extends Namer
i"""Implementation restriction: Expected result type $pt1
|is a curried dependent context function type. Such types are not yet supported.""",
tree.srcPos)

pt1 match {
case pt1 if defn.isNonRefinedFunction(pt1) =>
// if expected parameter type(s) are wildcards, approximate from below.
// if expected result type is a wildcard, approximate from above.
// this can type the greatest set of admissible closures.
(pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last)))
case SAMType(sam @ MethodTpe(_, formals, restpe)) =>
(formals,
if (sam.isResultDependent)
untpd.DependentTypeTree(syms => restpe.substParams(sam, syms.map(_.termRef)))
else
typeTree(restpe))
case tp: TypeParamRef =>
decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity, tree)
case _ =>
(List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
case _ => pt1.findFuntionTypeInUnion match {
case pt1 if defn.isNonRefinedFunction(pt1) =>
// if expected parameter type(s) are wildcards, approximate from below.
// if expected result type is a wildcard, approximate from above.
// this can type the greatest set of admissible closures.
(pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last)))
case SAMType(sam @ MethodTpe(_, formals, restpe)) =>
(formals,
if sam.isResultDependent then
untpd.DependentTypeTree(syms => restpe.substParams(sam, syms.map(_.termRef)))
else
typeTree(restpe))
case _ =>
(List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
}
}
}

Expand Down Expand Up @@ -1399,7 +1403,7 @@ class Typer extends Namer
if (tree.tpt.isEmpty)
meth1.tpe.widen match {
case mt: MethodType =>
pt.stripNull match {
pt.findFuntionTypeInUnion match {
case pt @ SAMType(sam)
if !defn.isFunctionType(pt) && mt <:< sam =>
// SAMs of the form C[?] where C is a class cannot be conversion targets.
Expand Down
4 changes: 4 additions & 0 deletions tests/explicit-nulls/pos/i11694.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = {
val x = new java.util.ArrayList[String]()
val y = x.stream().nn.filter(s => s.nn.length > 0)
}
19 changes: 19 additions & 0 deletions tests/neg/i11694.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
def test1 = {
def f11: (Int => Int) | Unit = x => x + 1
def f12: Null | (Int => Int) = x => x + 1

def f21: (Int => Int) | Null = x => x + 1
def f22: Null | (Int => Int) = x => x + 1
}

def test2 = {
def f1: (Int => String) | (Int => Int) | Null = x => x + 1 // error
def f2: (Int => String) | Function[String, Int] | Null = x => "" + x // error
def f3: Function[Int, Int] | Function[String, Int] | Null = x => x + 1 // error
}

def test3 = {
import java.util.function.Function
val f1: Function[String, Int] | Unit = x => x.length
val f2: Function[String, Int] | Null = x => x.length
}

0 comments on commit 1df84ec

Please sign in to comment.