Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better support type-heavy pattern matches #12549

Merged
merged 9 commits into from
Jun 3, 2021
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2586,6 +2586,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
provablyDisjoint(tp1, gadtBounds(tp2.symbol).hi) || provablyDisjoint(tp1, tp2.superType)
case (tp1: TermRef, tp2: TermRef) if isEnumValueOrModule(tp1) && isEnumValueOrModule(tp2) =>
tp1.termSymbol != tp2.termSymbol
case (tp1: TermRef, tp2: TypeRef) if isEnumValueOrModule(tp1) && !tp1.classSymbols.exists(_.derivesFrom(tp2.classSymbol)) =>
// Note: enum values may have multiple parents
true
case (tp1: TypeRef, tp2: TermRef) if isEnumValueOrModule(tp2) && !tp2.classSymbols.exists(_.derivesFrom(tp1.classSymbol)) =>
true
case (tp1: Type, tp2: Type) if defn.isTupleType(tp1) =>
provablyDisjoint(tp1.toNestedPairs, tp2)
case (tp1: Type, tp2: Type) if defn.isTupleType(tp2) =>
Expand Down
21 changes: 15 additions & 6 deletions compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -692,15 +692,24 @@ object TypeOps:
*/
private def instantiateToSubType(tp1: NamedType, tp2: Type)(using Context): Type = {
// In order for a child type S to qualify as a valid subtype of the parent
// T, we need to test whether it is possible S <: T. Therefore, we replace
// type parameters in T with tvars, and see if the subtyping is true.
val approximateTypeParams = new TypeMap {
// T, we need to test whether it is possible S <: T.
//
// The check is different from subtype checking due to type parameters and
// `this`. We perform the following operations to approximate the parameters:
//
// 1. Replace type parameters in T with tvars
// 2. Replace `A.this.C` with `A#C` (see tests/patmat/i12681.scala)
//
val approximateParent = new TypeMap {
val boundTypeParams = util.HashMap[TypeRef, TypeVar]()

def apply(tp: Type): Type = tp.dealias match {
case _: MatchType =>
tp // break cycles

case ThisType(tref: TypeRef) if !tref.symbol.isStaticOwner =>
tref

case tp: TypeRef if !tp.symbol.isClass =>
def lo = LazyRef.of(apply(tp.underlying.loBound))
def hi = LazyRef.of(apply(tp.underlying.hiBound))
Expand Down Expand Up @@ -787,7 +796,7 @@ object TypeOps:
// we manually patch subtyping check instead of changing TypeComparer.
// See tests/patmat/i3645b.scala
def parentQualify(tp1: Type, tp2: Type) = tp1.classSymbol.info.parents.exists { parent =>
parent.argInfos.nonEmpty && approximateTypeParams(parent) <:< tp2
parent.argInfos.nonEmpty && approximateParent(parent) <:< tp2
}

def instantiate(): Type = {
Expand All @@ -797,8 +806,8 @@ object TypeOps:

if (protoTp1 <:< tp2) instantiate()
else {
val protoTp2 = approximateTypeParams(tp2)
if (protoTp1 <:< protoTp2 || parentQualify(protoTp1, protoTp2)) instantiate()
val approxTp2 = approximateParent(tp2)
if (protoTp1 <:< approxTp2 || parentQualify(protoTp1, approxTp2)) instantiate()
else NoType
}
}
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ object SymUtils:
def whyNotGenericSum(declScope: Symbol)(using Context): String =
if (!self.is(Sealed))
s"it is not a sealed ${self.kindString}"
else if (!self.isOneOf(AbstractOrTrait))
s"it is not an abstract class"
else {
val children = self.children
val companionMirror = self.useCompanionAsMirror
Expand Down
36 changes: 26 additions & 10 deletions compiler/src/dotty/tools/dotc/transform/patmat/Space.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ trait SpaceLogic {
def signature(unapp: TermRef, scrutineeTp: Type, argLen: Int): List[Type]

/** Get components of decomposable types */
def decompose(tp: Type): List[Space]
def decompose(tp: Type): List[Typ]

/** Whether the extractor covers the given type */
def covers(unapp: TermRef, scrutineeTp: Type): Boolean
Expand Down Expand Up @@ -176,6 +176,8 @@ trait SpaceLogic {
ss.forall(isSubspace(_, b))
case (Typ(tp1, _), Typ(tp2, _)) =>
isSubType(tp1, tp2)
|| canDecompose(tp1) && tryDecompose1(tp1)
|| canDecompose(tp2) && tryDecompose2(tp2)
case (Typ(tp1, _), Or(ss)) => // optimization: don't go to subtraction too early
ss.exists(isSubspace(a, _)) || tryDecompose1(tp1)
case (_, Or(_)) =>
Expand Down Expand Up @@ -337,9 +339,7 @@ class SpaceEngine(using Context) extends SpaceLogic {
val res = TypeComparer.provablyDisjoint(tp1, tp2)

if (res) Empty
else if (tp1.isSingleton) Typ(tp1, true)
else if (tp2.isSingleton) Typ(tp2, true)
else Typ(AndType(tp1, tp2), true)
else Typ(AndType(tp1, tp2), decomposed = true)
}
}

Expand Down Expand Up @@ -591,14 +591,27 @@ class SpaceEngine(using Context) extends SpaceLogic {
}

/** Decompose a type into subspaces -- assume the type can be decomposed */
def decompose(tp: Type): List[Space] =
def decompose(tp: Type): List[Typ] =
tp.dealias match {
case AndType(tp1, tp2) =>
intersect(Typ(tp1, false), Typ(tp2, false)) match {
case Or(spaces) => spaces.toList
case Empty => Nil
case space => List(space)
}
def decomposeComponent(tpA: Type, tpB: Type): List[Typ] =
decompose(tpA).flatMap {
case Typ(tp, _) =>
if tp <:< tpB then
Typ(tp, decomposed = true) :: Nil
else if tpB <:< tp then
Typ(tpB, decomposed = true) :: Nil
else if TypeComparer.provablyDisjoint(tp, tpB) then
Nil
else
Typ(AndType(tp, tpB), decomposed = true) :: Nil
}

if canDecompose(tp1) then
decomposeComponent(tp1, tp2)
else
decomposeComponent(tp2, tp1)

case OrType(tp1, tp2) => List(Typ(tp1, true), Typ(tp2, true))
case tp if tp.isRef(defn.BooleanClass) =>
List(
Expand Down Expand Up @@ -833,6 +846,9 @@ class SpaceEngine(using Context) extends SpaceLogic {

if (!exhaustivityCheckable(sel)) return

debug.println("checking " + _match.show)
debug.println("selTyp = " + selTyp.show)

val patternSpace = Or(cases.foldLeft(List.empty[Space]) { (acc, x) =>
val space = if (x.guard.isEmpty) project(x.pat) else Empty
debug.println(s"${x.pat.show} ====> ${show(space)}")
Expand Down
13 changes: 3 additions & 10 deletions scaladoc/src/dotty/tools/scaladoc/tasty/TypesSupport.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,8 @@ trait TypesSupport:
case ThisType(tpe) => inner(tpe)
case AnnotatedType(tpe, _) => inner(tpe)
case AppliedType(tpe, _) => inner(tpe)
case tp @ TermRef(qual, typeName) =>
qual match
case _: TypeRepr | _: NoPrefix => Some(tp.termSymbol)
case other => None
case tp @ TypeRef(qual, typeName) =>
qual match
case _: TypeRepr | _: NoPrefix => Some(tp.typeSymbol)
case other => None
case tp @ TermRef(qual, typeName) => Some(tp.termSymbol)
case tp @ TypeRef(qual, typeName) => Some(tp.typeSymbol)

val typeSymbol = extractTypeSymbol(method.returnTpt)

Expand Down Expand Up @@ -204,8 +198,7 @@ trait TypesSupport:
case tp @ TypeRef(qual, typeName) =>
qual match {
case r: RecursiveThis => texts(s"this.$typeName")
case _: TypeRepr | _: NoPrefix => link(tp.typeSymbol)
case other => noSupported(s"TypeRepr: $tp")
case _: TypeRepr => link(tp.typeSymbol)
}
// convertTypeOrBoundsToReference(reflect)(qual) match {
// case TypeReference(label, link, xs, _) => TypeReference(typeName, link + "/" + label, xs, true)
Expand Down
15 changes: 15 additions & 0 deletions tests/patmat/i10667.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
sealed trait A

enum Nums {
case One
case Two extends Nums with A
case Three
}

object Test {
val list = List[Nums & A](Nums.Two)

list.map {
case Nums.Two => ()
}
}
20 changes: 20 additions & 0 deletions tests/patmat/i12475.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
sealed trait Ty {
type T
}

class TUnit() extends Ty {
type T = Unit
}

case object TUnit extends TUnit()

final case class TFun(dom: Ty, cod: Ty) extends Ty {
type T = dom.T => cod.T
}

def default(ty: Ty): ty.T = (ty: ty.type & Ty) match {
case a: (ty.type & TUnit) => (): a.T
case a: (ty.type & TFun) =>
val f = { (x: a.dom.T) => default(a.cod) }
f: a.T
}
14 changes: 14 additions & 0 deletions tests/patmat/i12475b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
trait SomeRestriction

enum ADT {
case A
case B extends ADT with SomeRestriction
}

object MinimalExample {
val b: ADT & SomeRestriction = ADT.B

b match {
case ADT.B => ???
}
}
14 changes: 14 additions & 0 deletions tests/patmat/i12546.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
trait SomeRestriction

enum ADT {
case A extends ADT
case B extends ADT with SomeRestriction
}

object MinimalExample {
val b: ADT & SomeRestriction = ADT.B

b match {
case ADT.B => ???
}
}
2 changes: 2 additions & 0 deletions tests/patmat/i12559.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
10: Match case Unreachable
27: Match case Unreachable
35 changes: 35 additions & 0 deletions tests/patmat/i12559.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package akka.event

object TestA:
sealed trait LogEvent

object LogEvent:
def myOrdinal(e: LogEvent): Int = e match
case e: Error => 0
// case e: Warning => 1
case e: LogEventWithMarker => 2


class Error() extends LogEvent
class Error2() extends Error() with LogEventWithMarker

// case class Warning() extends LogEvent

sealed trait LogEventWithMarker extends LogEvent

object TestB:
sealed trait LogEvent

object LogEvent:
def myOrdinal(e: LogEvent): Int = e match
case e: Error => 0
case e: Warning => 1
case e: LogEventWithMarker => 2


case class Error() extends LogEvent
class Error2() extends Error() with LogEventWithMarker

case class Warning() extends LogEvent

sealed trait LogEventWithMarker extends LogEvent
2 changes: 2 additions & 0 deletions tests/patmat/i12602.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
sealed class Foo[T]
object Foo extends Foo[Nothing]
20 changes: 20 additions & 0 deletions tests/patmat/i12681.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
object Examples {

case class Leaf1() extends i.Root
case class Leaf2() extends i.Branch

val i = new Inner()

class Inner {

sealed trait Root
sealed trait Branch extends Root

// simulate ordinal method of a Mirror.SumOf generated at this call site
def myOrdinal(r: Root): Int = r match {
case _: Examples.Leaf1 => 0
case _: Inner.this.Branch => 1
}
}

}