diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 684b8faa13ea..a5a34250d89c 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -1651,6 +1651,24 @@ object Types { case _ => resultType } + /** Find the function type in union. + * If there are multiple function types, NoType is returned. + */ + def findFunctionTypeInUnion(using Context): Type = this match { + case t: OrType => + val t1 = t.tp1.findFunctionTypeInUnion + if t1 == NoType then t.tp2.findFunctionTypeInUnion else + val t2 = t.tp2.findFunctionTypeInUnion + // Returen NoType if the union contains multiple function types + if t2 == NoType then t1 else NoType + case t if defn.isNonRefinedFunction(t) => + t + case t @ SAMType(_) => + t + case _ => + NoType + } + /** This type seen as a TypeBounds */ final def bounds(using Context): TypeBounds = this match { case tp: TypeBounds => tp diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 17bb2f9dd2df..bce9aff02be1 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1125,6 +1125,7 @@ class Typer extends Namer newTypeVar(apply(bounds.orElse(TypeBounds.empty)).bounds) case _ => mapOver(t) } + val pt1 = pt.stripTypeVar.dealias if (pt1 ne pt1.dropDependentRefinement) && defn.isContextFunctionType(pt1.nonPrivateMember(nme.apply).info.finalResultType) @@ -1133,22 +1134,25 @@ class Typer extends Namer i"""Implementation restriction: Expected result type $pt1 |is a curried dependent context function type. Such types are not yet supported.""", tree.srcPos) + pt1 match { - case pt1 if defn.isNonRefinedFunction(pt1) => - // if expected parameter type(s) are wildcards, approximate from below. - // if expected result type is a wildcard, approximate from above. - // this can type the greatest set of admissible closures. - (pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last))) - case SAMType(sam @ MethodTpe(_, formals, restpe)) => - (formals, - if (sam.isResultDependent) - untpd.DependentTypeTree(syms => restpe.substParams(sam, syms.map(_.termRef))) - else - typeTree(restpe)) case tp: TypeParamRef => decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity, tree) - case _ => - (List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree()) + case _ => pt1.findFunctionTypeInUnion match { + case pt1 if defn.isNonRefinedFunction(pt1) => + // if expected parameter type(s) are wildcards, approximate from below. + // if expected result type is a wildcard, approximate from above. + // this can type the greatest set of admissible closures. + (pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last))) + case SAMType(sam @ MethodTpe(_, formals, restpe)) => + (formals, + if sam.isResultDependent then + untpd.DependentTypeTree(syms => restpe.substParams(sam, syms.map(_.termRef))) + else + typeTree(restpe)) + case _ => + (List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree()) + } } } @@ -1399,7 +1403,7 @@ class Typer extends Namer if (tree.tpt.isEmpty) meth1.tpe.widen match { case mt: MethodType => - pt.stripNull match { + pt.findFunctionTypeInUnion match { case pt @ SAMType(sam) if !defn.isFunctionType(pt) && mt <:< sam => // SAMs of the form C[?] where C is a class cannot be conversion targets. diff --git a/tests/explicit-nulls/pos/i11694.scala b/tests/explicit-nulls/pos/i11694.scala new file mode 100644 index 000000000000..8098775a8430 --- /dev/null +++ b/tests/explicit-nulls/pos/i11694.scala @@ -0,0 +1,4 @@ +def test = { + val x = new java.util.ArrayList[String]() + val y = x.stream().nn.filter(s => s.nn.length > 0) +} \ No newline at end of file diff --git a/tests/neg/i11694.scala b/tests/neg/i11694.scala new file mode 100644 index 000000000000..5bbad1a83ce2 --- /dev/null +++ b/tests/neg/i11694.scala @@ -0,0 +1,19 @@ +def test1 = { + def f11: (Int => Int) | Unit = x => x + 1 + def f12: Null | (Int => Int) = x => x + 1 + + def f21: (Int => Int) | Null = x => x + 1 + def f22: Null | (Int => Int) = x => x + 1 +} + +def test2 = { + def f1: (Int => String) | (Int => Int) | Null = x => x + 1 // error + def f2: (Int => String) | Function[String, Int] | Null = x => "" + x // error + def f3: Function[Int, Int] | Function[String, Int] | Null = x => x + 1 // error +} + +def test3 = { + import java.util.function.Function + val f1: Function[String, Int] | Unit = x => x.length + val f2: Function[String, Int] | Null = x => x.length +}