Skip to content

Commit

Permalink
Merge pull request #15625 from dotty-staging/fix-15618
Browse files Browse the repository at this point in the history
Fix two problems related to match types as array elements
  • Loading branch information
odersky authored Jul 10, 2022
2 parents 794e7c9 + 8baaeae commit 11d65aa
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 25 deletions.
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,9 @@ object TypeErasure {
isGenericArrayElement(tp.alias, isScala2)
case tp: TypeBounds =>
!fitsInJVMArray(tp.hi)
case tp: MatchType =>
val alts = tp.alternatives
alts.nonEmpty && !fitsInJVMArray(alts.reduce(OrType(_, _, soft = true)))
case tp: TypeProxy =>
isGenericArrayElement(tp.translucentSuperType, isScala2)
case tp: AndType =>
Expand Down
12 changes: 7 additions & 5 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,13 @@ object Types {

/** Is this a match type or a higher-kinded abstraction of one?
*/
def isMatch(using Context): Boolean = stripped match {
case _: MatchType => true
case tp: HKTypeLambda => tp.resType.isMatch
case tp: AppliedType => tp.isMatchAlias
case _ => false
def isMatch(using Context): Boolean = underlyingMatchType.exists

def underlyingMatchType(using Context): Type = stripped match {
case tp: MatchType => tp
case tp: HKTypeLambda => tp.resType.underlyingMatchType
case tp: AppliedType if tp.isMatchAlias => tp.superType.underlyingMatchType
case _ => NoType
}

/** Is this a higher-kinded type lambda with given parameter variances? */
Expand Down
49 changes: 29 additions & 20 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,35 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
private type SpecialHandlers = List[(ClassSymbol, SpecialHandler)]

val synthesizedClassTag: SpecialHandler = (formal, span) =>
formal.argInfos match
case arg :: Nil =>
if isFullyDefined(arg, ForceDegree.all) then
arg match
case defn.ArrayOf(elemTp) =>
val etag = typer.inferImplicitArg(defn.ClassTagClass.typeRef.appliedTo(elemTp), span)
if etag.tpe.isError then EmptyTreeNoError else withNoErrors(etag.select(nme.wrap))
case tp if hasStableErasure(tp) && !defn.isBottomClassAfterErasure(tp.typeSymbol) =>
val sym = tp.typeSymbol
val classTag = ref(defn.ClassTagModule)
val tag =
if defn.SpecialClassTagClasses.contains(sym) then
classTag.select(sym.name.toTermName)
else
val clsOfType = escapeJavaArray(erasure(tp))
classTag.select(nme.apply).appliedToType(tp).appliedTo(clsOf(clsOfType))
withNoErrors(tag.withSpan(span))
case tp => EmptyTreeNoError
else EmptyTreeNoError
case _ => EmptyTreeNoError
val tag = formal.argInfos match
case arg :: Nil if isFullyDefined(arg, ForceDegree.all) =>
arg match
case defn.ArrayOf(elemTp) =>
val etag = typer.inferImplicitArg(defn.ClassTagClass.typeRef.appliedTo(elemTp), span)
if etag.tpe.isError then EmptyTree else etag.select(nme.wrap)
case tp if hasStableErasure(tp) && !defn.isBottomClassAfterErasure(tp.typeSymbol) =>
val sym = tp.typeSymbol
val classTagModul = ref(defn.ClassTagModule)
if defn.SpecialClassTagClasses.contains(sym) then
classTagModul.select(sym.name.toTermName).withSpan(span)
else
def clsOfType(tp: Type): Type = tp.dealias.underlyingMatchType match
case matchTp: MatchType =>
matchTp.alternatives.map(clsOfType) match
case ct1 :: cts if cts.forall(ct1 == _) => ct1
case _ => NoType
case _ =>
escapeJavaArray(erasure(tp))
val ctype = clsOfType(tp)
if ctype.exists then
classTagModul.select(nme.apply)
.appliedToType(tp)
.appliedTo(clsOf(ctype))
.withSpan(span)
else EmptyTree
case _ => EmptyTree
case _ => EmptyTree
(tag, Nil)
end synthesizedClassTag

val synthesizedTypeTest: SpecialHandler =
Expand Down
18 changes: 18 additions & 0 deletions tests/neg/i15618.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
-- Error: tests/neg/i15618.scala:17:44 ---------------------------------------------------------------------------------
17 | def toArray: Array[ScalaType[T]] = Array() // error
| ^
| No ClassTag available for ScalaType[T]
|
| where: T is a type in class Tensor with bounds <: DType
|
|
| Note: a match type could not be fully reduced:
|
| trying to reduce ScalaType[T]
| failed since selector T
| does not match case Float16 => Float
| and cannot be shown to be disjoint from it either.
| Therefore, reduction cannot advance to the remaining cases
|
| case Float32 => Float
| case Int32 => Int
23 changes: 23 additions & 0 deletions tests/neg/i15618.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
sealed abstract class DType
sealed class Float16 extends DType
sealed class Float32 extends DType
sealed class Int32 extends DType

object Float16 extends Float16
object Float32 extends Float32
object Int32 extends Int32

type ScalaType[U <: DType] <: Int | Float = U match
case Float16 => Float
case Float32 => Float
case Int32 => Int

class Tensor[T <: DType](dtype: T):
def toSeq: Seq[ScalaType[T]] = Seq()
def toArray: Array[ScalaType[T]] = Array() // error

@main
def Test =
val t = Tensor(Float32) // Tensor[Float32]
println(t.toSeq.headOption) // works, Seq[Float]
println(t.toArray.headOption) // ClassCastException
1 change: 1 addition & 0 deletions tests/run/i15618.check
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Some(1)
24 changes: 24 additions & 0 deletions tests/run/i15618.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
sealed abstract class DType
sealed class Float16 extends DType
sealed class Float32 extends DType
sealed class Int32 extends DType

object Float16 extends Float16
object Float32 extends Float32
object Int32 extends Int32

type ScalaType[U <: DType] <: Int | Float = U match
case Float16 => Float
case Float32 => Float
case Int32 => Int

abstract class Tensor[T <: DType]:
def toArray: Array[ScalaType[T]]

object IntTensor extends Tensor[Int32]:
def toArray: Array[Int] = Array(1, 2, 3)

@main
def Test =
val t = IntTensor: Tensor[Int32]
println(t.toArray.headOption) // was ClassCastException

0 comments on commit 11d65aa

Please sign in to comment.