Skip to content

Commit

Permalink
infer: Use GADT constraints in maximiseType
Browse files Browse the repository at this point in the history
Consider the GADT constraints during Inferencing's maximiseType to avoid
instantiating type variables that lead to GADT casting inserting unsound
casts.
  • Loading branch information
dwijnand committed Jun 30, 2022
1 parent c9dded7 commit ebb5c6d
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 7 deletions.
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ sealed abstract class GadtConstraint extends Showable {
/** See [[ConstraintHandling.approximation]] */
def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type

def symbols: List[Symbol]

def fresh: GadtConstraint

/** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */
Expand Down Expand Up @@ -209,6 +211,8 @@ final class ProperGadtConstraint private(
res
}

override def symbols: List[Symbol] = mapping.keys

override def fresh: GadtConstraint = new ProperGadtConstraint(
myConstraint,
mapping,
Expand Down Expand Up @@ -307,6 +311,8 @@ final class ProperGadtConstraint private(

override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")

override def symbols: List[Symbol] = Nil

override def fresh = new ProperGadtConstraint
override def restore(other: GadtConstraint): Unit =
assert(!other.isNarrowing, "cannot restore a non-empty GADTMap")
Expand Down
12 changes: 8 additions & 4 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ import collection.mutable

import scala.annotation.internal.sharable

import config.Printers.gadts

object Inferencing {

import tpd._
Expand Down Expand Up @@ -411,10 +409,16 @@ object Inferencing {
Stats.record("maximizeType")
val vs = variances(tp)
val patternBindings = new mutable.ListBuffer[(Symbol, TypeParamRef)]
val gadtBounds = ctx.gadt.symbols.map(ctx.gadt.bounds(_).nn)
vs foreachBinding { (tvar, v) =>
if !tvar.isInstantiated then
if (v == 1) tvar.instantiate(fromBelow = false)
else if (v == -1) tvar.instantiate(fromBelow = true)
// if the tvar is covariant/contravariant (v == 1/-1, respectively) in the input type tp
// then check the tvar doesn't occur in the opposite GADT bound (lower/upper) within any of the GADT bounds
// if it doesn't occur then it's safe to instantiate the tvar
// Eg neg/i14983 the C in Node[+C] is in the GADT lower bound X >: List[C] so maximising to Node[Any] is unsound
// Eg pos/precise-pattern-type the T in Tree[-T] is in no GADT upper bound so can maximise to Tree[Type]
val safeToInstantiate = v != 0 && gadtBounds.forall(tb => !tvar.occursIn(if v == 1 then tb.lo else tb.hi))
if safeToInstantiate then tvar.instantiate(fromBelow = v == -1)
else {
val bounds = TypeComparer.fullBounds(tvar.origin)
if bounds.hi <:< bounds.lo || bounds.hi.classSymbol.is(Final) then
Expand Down
4 changes: 1 addition & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3762,9 +3762,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
res
} =>
// Insert an explicit cast, so that -Ycheck in later phases succeeds.
// I suspect, but am not 100% sure that this might affect inferred types,
// if the expected type is a supertype of the GADT bound. It would be good to come
// up with a test case for this.
// The check "safeToInstantiate" in `maximizeType` works to prevent unsound GADT casts.
val target =
if tree.tpe.isSingleton then
val conj = AndType(tree.tpe, pt)
Expand Down
15 changes: 15 additions & 0 deletions tests/neg/i14983.contra.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
sealed trait Show[-A]
final case class Pure[-B](showB: B => String) extends Show[B]
final case class Many[-C](showL: List[C] => String) extends Show[List[C]]

object Test:
def meth[X](show: Show[X]): X => String = show match
case Pure(showB) => showB
case Many(showL) =>
val res = (xs: List[String]) => xs.head.length.toString
res // error: Found: List[String] => String Required: X => String where: X is a type in method meth with bounds <: List[C$1]

def main(args: Array[String]): Unit =
val show = Many((is: List[Int]) => (is.head + 1).toString)
val fn = meth(show)
assert(fn(List(42)) == "43") // was: ClassCastException: class java.lang.Integer cannot be cast to class java.lang.String
22 changes: 22 additions & 0 deletions tests/neg/i14983.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
sealed trait Tree[+A]
final case class Leaf[+B](b: B) extends Tree[B]
final case class Node[+C](l: List[C]) extends Tree[List[C]]

// The original test case, minimised.
object Test:
def meth[X](tree: Tree[X]): X = tree match
case Leaf(v) => v // ok: Tree[X] vs Leaf[B], PTC: X >: B, max: Leaf[B] => Leaf[X], x: X
case Node(x) =>
// tree: Tree[X] vs Node[C] aka Tree[List[C]]
// PTC: X >: List[C]
// max: Node[C] => Node[Any], instantiating C := Any, which makes X >: List[Any]
// adapt: List[String] <: X = OKwithGADTUsed; insert GADT cast asInstanceOf[X]
List("boom") // error: Found: List[String] Required: X where: X is a type in method meth with bounds >: List[C$1]
// after fix:
// max: Node[C] => Node[C$1], instantiating C := C$1, a new symbol, so X >: List[C$1]
// adapt: List[String] <: X = Fail, because String !<: C$1

def main(args: Array[String]): Unit =
val tree = Node(List(42))
val res = meth(tree)
assert(res.head == 42) // was: ClassCastException: class java.lang.String cannot be cast to class java.lang.Integer
14 changes: 14 additions & 0 deletions tests/run/i14983.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
sealed trait Tree[+A]
final case class Leaf[+B](b: B) extends Tree[B]
final case class Node[+C](l: List[C]) extends Tree[List[C]]

// A version of the original test case that is sound so should typecheck.
object Test:
def meth[X](tree: Tree[X]): X = tree match
case Leaf(v) => v // ok: Tree[X] vs Leaf[B], PTC: X >: B, max: Leaf[B] => Leaf[X], x: X <:< X
case Node(x) => x // ok: Tree[X] vs Node[C], PTC: X >: List[C], max: Node[C] => Node[C$1], x: C$1 <:< X, w/ GADT cast

def main(args: Array[String]): Unit =
val tree = Node(List(42))
val res = meth(tree)
assert(res.head == 42) // ok

0 comments on commit ebb5c6d

Please sign in to comment.