From 2dd1fa30cc24ce046ed7dc3c8bf6e4fdd51b8edf Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Tue, 6 Feb 2024 15:45:15 -0500 Subject: [PATCH] Generalise `SSubsetStruct` to views on `SBaseStruct`s --- .../src/main/scala/is/hail/expr/ir/Emit.scala | 2 +- .../is/hail/types/physical/PBaseStruct.scala | 2 +- .../hail/types/physical/PSubsetStruct.scala | 8 +- .../is/hail/types/physical/stypes/SCode.scala | 4 + .../stypes/concrete/SStackStruct.scala | 2 +- .../stypes/concrete/SStructView.scala | 157 ++++++++++++++++++ .../stypes/concrete/SSubsetStruct.scala | 144 ---------------- .../stypes/interfaces/SBaseStruct.scala | 2 +- .../is/hail/types/virtual/TBaseStruct.scala | 3 +- .../stypes/concrete/SStructViewSuite.scala | 60 +++++++ .../stypes/concrete/SSubsetStructSuite.scala | 43 ----- 11 files changed, 231 insertions(+), 196 deletions(-) create mode 100644 hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStructView.scala delete mode 100644 hail/src/main/scala/is/hail/types/physical/stypes/concrete/SSubsetStruct.scala create mode 100644 hail/src/test/scala/is/hail/types/physical/stypes/concrete/SStructViewSuite.scala delete mode 100644 hail/src/test/scala/is/hail/types/physical/stypes/concrete/SSubsetStructSuite.scala diff --git a/hail/src/main/scala/is/hail/expr/ir/Emit.scala b/hail/src/main/scala/is/hail/expr/ir/Emit.scala index 06e872a57b0..812cb2c3ee2 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Emit.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Emit.scala @@ -1122,7 +1122,7 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { iec.map(cb)(pc => cast(cb, pc)) case CastRename(v, _typ) => emitI(v) - .map(cb)(pc => pc.st.castRename(_typ).fromValues(pc.valueTuple)) + .map(cb)(_.castRename(_typ)) case NA(typ) => IEmitCode.missing(cb, SUnreachable.fromVirtualType(typ).defaultValue) case IsNA(v) => diff --git a/hail/src/main/scala/is/hail/types/physical/PBaseStruct.scala b/hail/src/main/scala/is/hail/types/physical/PBaseStruct.scala index 5e9c2a415a0..2a6c1e7ca23 100644 --- a/hail/src/main/scala/is/hail/types/physical/PBaseStruct.scala +++ b/hail/src/main/scala/is/hail/types/physical/PBaseStruct.scala @@ -43,7 +43,7 @@ abstract class PBaseStruct extends PType { def size: Int = fields.length - def isIsomorphicTo(other: PBaseStruct) = + def isIsomorphicTo(other: PBaseStruct): Boolean = this.fields.size == other.fields.size && this.isCompatibleWith(other) def _toPretty: String = { diff --git a/hail/src/main/scala/is/hail/types/physical/PSubsetStruct.scala b/hail/src/main/scala/is/hail/types/physical/PSubsetStruct.scala index ed873db7dd6..f575965c369 100644 --- a/hail/src/main/scala/is/hail/types/physical/PSubsetStruct.scala +++ b/hail/src/main/scala/is/hail/types/physical/PSubsetStruct.scala @@ -5,8 +5,8 @@ import is.hail.asm4s.{Code, Value} import is.hail.backend.HailStateManager import is.hail.expr.ir.EmitCodeBuilder import is.hail.types.physical.stypes.SValue -import is.hail.types.physical.stypes.concrete.SSubsetStruct -import is.hail.types.physical.stypes.interfaces.SBaseStructValue +import is.hail.types.physical.stypes.concrete.SStructView +import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructValue} import is.hail.types.virtual.TStruct import is.hail.utils._ @@ -137,8 +137,8 @@ final case class PSubsetStruct(ps: PStruct, _fieldNames: IndexedSeq[String]) ext ): Long = throw new UnsupportedOperationException - def sType: SSubsetStruct = - new SSubsetStruct(ps.sType, _fieldNames) + def sType: SBaseStruct = + SStructView.subset(_fieldNames, ps.sType) def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) : Value[Long] = diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/SCode.scala b/hail/src/main/scala/is/hail/types/physical/stypes/SCode.scala index 3180ea74155..1f1b6040274 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/SCode.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/SCode.scala @@ -6,6 +6,7 @@ import is.hail.expr.ir.EmitCodeBuilder import is.hail.types.physical.stypes.concrete.SRNGStateValue import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives._ +import is.hail.types.virtual.Type object SCode { def add(cb: EmitCodeBuilder, left: SValue, right: SValue, required: Boolean): SValue = { @@ -114,6 +115,9 @@ trait SValue { throw new UnsupportedOperationException(s"Stype $st has no hashcode") def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value + + def castRename(t: Type): SValue = + st.castRename(t).fromValues(valueTuple) } trait SSettable extends SValue { diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStackStruct.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStackStruct.scala index c5c42a2bd43..9b7e3bcc9f3 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStackStruct.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStackStruct.scala @@ -178,7 +178,7 @@ class SStackStructValue(val st: SStackStruct, val values: IndexedSeq[EmitValue]) override def isFieldMissing(cb: EmitCodeBuilder, fieldIdx: Int): Value[Boolean] = values(fieldIdx).m - override def subset(fieldNames: String*): SStackStructValue = { + override def subset(fieldNames: String*): SBaseStructValue = { val newToOld = fieldNames.map(st.fieldIdx).toArray val oldVType = st.virtualType.asInstanceOf[TStruct] val newVirtualType = TStruct(newToOld.map(i => (oldVType.fieldNames(i), oldVType.types(i))): _*) diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStructView.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStructView.scala new file mode 100644 index 00000000000..03541958f1f --- /dev/null +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStructView.scala @@ -0,0 +1,157 @@ +package is.hail.types.physical.stypes.concrete + +import is.hail.annotations.Region +import is.hail.asm4s.{Settable, TypeInfo, Value} +import is.hail.expr.ir.{EmitCodeBuilder, IEmitCode} +import is.hail.types.physical.{PCanonicalStruct, PType} +import is.hail.types.physical.stypes.{EmitType, SType, SValue} +import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructSettable, SBaseStructValue} +import is.hail.types.virtual.{TBaseStruct, TStruct, Type} + +object SStructView { + def subset(fieldnames: IndexedSeq[String], struct: SBaseStruct): SStructView = + struct match { + case s @ SStructView(parent, restrict, rename) => + SStructView( + parent, + restrict.typeAfterSelect(fieldnames.map(s.fieldIdx)), + rename.typeAfterSelectNames(fieldnames), + ) + + case s => + val restrict = s.virtualType.asInstanceOf[TStruct].typeAfterSelectNames(fieldnames) + SStructView(s, restrict, restrict) + } +} + +// A 'view' on `SBaseStruct`s, ie one that presents an upcast and/or renamed facade on another +final case class SStructView(parent: SBaseStruct, restrict: TStruct, rename: TStruct) + extends SBaseStruct { + + assert( + restrict.canCastTo(rename), + f"""Upcast operations are not isomorphic + | restrict: '${restrict._toPretty}' + | rename: '${rename._toPretty}' + |""".stripMargin, + ) + + override val size: Int = + restrict.fields.length + + lazy val newToOldFieldMapping: Map[Int, Int] = + restrict.fields.view.map(f => f.index -> parent.fieldIdx(f.name)).toMap + + override lazy val fieldTypes: IndexedSeq[SType] = + Array.tabulate(size) { i => + parent + .fieldTypes(newToOldFieldMapping(i)) + .castRename(rename.fields(i).typ) + } + + override lazy val fieldEmitTypes: IndexedSeq[EmitType] = + Array.tabulate(size) { i => + parent + .fieldEmitTypes(newToOldFieldMapping(i)) + .copy(st = fieldTypes(i)) + } + + override def virtualType: TBaseStruct = + rename + + override def fieldIdx(fieldName: String): Int = + rename.fieldIdx(fieldName) + + override def castRename(t: Type): SType = + SStructView(parent, restrict, rename = t.asInstanceOf[TStruct]) + + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = { + if (deepCopy) + throw new NotImplementedError("Deep copy on struct view") + + value.st match { + case s: SStructView if this == s && !deepCopy => + value + } + } + + override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = + parent.settableTupleTypes() + + override def fromSettables(settables: IndexedSeq[Settable[_]]): SStructViewSettable = + new SStructViewSettable( + this, + parent.fromSettables(settables).asInstanceOf[SBaseStructSettable], + ) + + override def fromValues(values: IndexedSeq[Value[_]]): SStructViewValue = + new SStructViewValue(this, parent.fromValues(values).asInstanceOf[SBaseStructValue]) + + override def copiedType: SType = + if (virtualType.size < 64) + SStackStruct(virtualType, fieldEmitTypes.map(_.copiedType)) + else { + val ct = SBaseStructPointer(storageType().asInstanceOf[PCanonicalStruct]) + assert(ct.virtualType == virtualType, s"ct=$ct, this=$this") + ct + } + + def storageType(): PType = { + val pt = PCanonicalStruct( + required = false, + args = rename.fieldNames.zip(fieldEmitTypes.map(_.copiedType.storageType)): _*, + ) + assert(pt.virtualType == virtualType, s"pt=$pt, this=$this") + pt + } + + // aspirational implementation + // def storageType(): PType = StoredSTypePType(this, false) + + override def containsPointers: Boolean = + parent.containsPointers + + override def equals(obj: Any): Boolean = + obj match { + case s: SStructView => + rename == s.rename && + newToOldFieldMapping == s.newToOldFieldMapping && + parent == s.parent // todo test isIsomorphicTo + case _ => + false + } +} + +class SStructViewValue(val st: SStructView, val prev: SBaseStructValue) extends SBaseStructValue { + + override lazy val valueTuple: IndexedSeq[Value[_]] = + prev.valueTuple + + override def subset(fieldNames: String*): SBaseStructValue = + new SStructViewValue(SStructView.subset(fieldNames.toIndexedSeq, st), prev) + + override def loadField(cb: EmitCodeBuilder, fieldIdx: Int): IEmitCode = + prev + .loadField(cb, st.newToOldFieldMapping(fieldIdx)) + .map(cb)(_.castRename(st.virtualType.fields(fieldIdx).typ)) + + override def isFieldMissing(cb: EmitCodeBuilder, fieldIdx: Int): Value[Boolean] = + prev.isFieldMissing(cb, st.newToOldFieldMapping(fieldIdx)) +} + +final class SStructViewSettable(st: SStructView, prev: SBaseStructSettable) + extends SStructViewValue(st, prev) with SBaseStructSettable { + override def subset(fieldNames: String*): SBaseStructValue = + new SStructViewSettable(SStructView.subset(fieldNames.toIndexedSeq, st), prev) + + override def settableTuple(): IndexedSeq[Settable[_]] = + prev.settableTuple() + + override def store(cb: EmitCodeBuilder, pv: SValue): Unit = + prev.store(cb, pv.asInstanceOf[SStructViewValue].prev) +} diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SSubsetStruct.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SSubsetStruct.scala deleted file mode 100644 index c0de0f2f6df..00000000000 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SSubsetStruct.scala +++ /dev/null @@ -1,144 +0,0 @@ -package is.hail.types.physical.stypes.concrete - -import is.hail.annotations.Region -import is.hail.asm4s.{Settable, TypeInfo, Value} -import is.hail.expr.ir.{EmitCodeBuilder, IEmitCode} -import is.hail.types.physical.{PCanonicalStruct, PType} -import is.hail.types.physical.stypes.{EmitType, SType, SValue} -import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructSettable, SBaseStructValue} -import is.hail.types.virtual.{Field, TStruct, Type} - -class SSubsetStruct( - private val parent: SBaseStruct, - private val fieldNames: IndexedSeq[String], -) extends SBaseStruct { - - override val size: Int = fieldNames.size - - val _fieldIdx: Map[String, Int] = fieldNames.zipWithIndex.toMap - - lazy val newToOldFieldMapping: Map[Int, Int] = { - val parentFieldIdx = parent.virtualType.asInstanceOf[TStruct] - _fieldIdx.map { case (f, i) => i -> parentFieldIdx.fieldIdx(f) } - } - - override lazy val fieldTypes: IndexedSeq[SType] = - Array.tabulate(size)(i => parent.fieldTypes(newToOldFieldMapping(i))) - - override lazy val fieldEmitTypes: IndexedSeq[EmitType] = - Array.tabulate(size)(i => parent.fieldEmitTypes(newToOldFieldMapping(i))) - - override lazy val virtualType: TStruct = { - val vparentTypes = parent.virtualType.asInstanceOf[TStruct].types - TStruct(fieldNames.zipWithIndex.map { case (f, i) => - Field(f, vparentTypes(newToOldFieldMapping(i)), i) - }) - } - - override def fieldIdx(fieldName: String): Int = - _fieldIdx(fieldName) - - override def castRename(t: Type): SType = { - val newVirtualType = t.asInstanceOf[TStruct] - val oldToNewFieldMapping = newToOldFieldMapping.map(n => n._2 -> n._1) - - // note we may have subsetted a parent struct{x,y,z} to struct{z} then renamed to struct{x} - // must only tell parent to castRename what it knows as `z`, leaving others intact - val newParent = parent.castRename( - TStruct(parent.virtualType.fields.zipWithIndex.map { case (f, i) => - oldToNewFieldMapping.get(i) match { - case Some(idx) => f.copy(typ = newVirtualType.types(idx)) - case None => f - } - }) - ) - - new SSubsetStruct(newParent.asInstanceOf[SBaseStruct], newVirtualType.fieldNames) { - override lazy val newToOldFieldMapping: Map[Int, Int] = - SSubsetStruct.this.newToOldFieldMapping - override lazy val virtualType: TStruct = - newVirtualType - } - } - - override def _coerceOrCopy( - cb: EmitCodeBuilder, - region: Value[Region], - value: SValue, - deepCopy: Boolean, - ): SValue = { - if (deepCopy) - throw new NotImplementedError("Deep copy on subset struct") - - value.st match { - case ss: SSubsetStruct if parent == ss.parent && fieldNames == ss.fieldNames && !deepCopy => - value - } - } - - override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = parent.settableTupleTypes() - - override def fromSettables(settables: IndexedSeq[Settable[_]]): SSubsetStructSettable = - new SSubsetStructSettable( - this, - parent.fromSettables(settables).asInstanceOf[SBaseStructSettable], - ) - - override def fromValues(values: IndexedSeq[Value[_]]): SSubsetStructValue = - new SSubsetStructValue(this, parent.fromValues(values).asInstanceOf[SBaseStructValue]) - - override def copiedType: SType = { - if (virtualType.size < 64) - SStackStruct(virtualType, fieldEmitTypes.map(_.copiedType)) - else { - val ct = SBaseStructPointer(PCanonicalStruct( - false, - virtualType.fieldNames.zip(fieldEmitTypes.map(_.copiedType.storageType)): _* - )) - assert(ct.virtualType == virtualType, s"ct=$ct, this=$this") - ct - } - } - - def storageType(): PType = { - val pt = PCanonicalStruct( - false, - virtualType.fieldNames.zip(fieldEmitTypes.map(_.copiedType.storageType)): _* - ) - assert(pt.virtualType == virtualType, s"pt=$pt, this=$this") - pt - } - - // aspirational implementation - // def storageType(): PType = StoredSTypePType(this, false) - - override def containsPointers: Boolean = parent.containsPointers - - override def equals(obj: Any): Boolean = - obj match { - case s: SSubsetStruct => - newToOldFieldMapping == s.newToOldFieldMapping && - parent.fieldTypes == s.parent.fieldTypes - case _ => - false - } -} - -class SSubsetStructValue(val st: SSubsetStruct, val prev: SBaseStructValue) - extends SBaseStructValue { - override lazy val valueTuple: IndexedSeq[Value[_]] = prev.valueTuple - - override def loadField(cb: EmitCodeBuilder, fieldIdx: Int): IEmitCode = - prev.loadField(cb, st.newToOldFieldMapping(fieldIdx)) - - override def isFieldMissing(cb: EmitCodeBuilder, fieldIdx: Int): Value[Boolean] = - prev.isFieldMissing(cb, st.newToOldFieldMapping(fieldIdx)) -} - -final class SSubsetStructSettable(st: SSubsetStruct, prev: SBaseStructSettable) - extends SSubsetStructValue(st, prev) with SBaseStructSettable { - override def settableTuple(): IndexedSeq[Settable[_]] = prev.settableTuple() - - override def store(cb: EmitCodeBuilder, pv: SValue): Unit = - prev.store(cb, pv.asInstanceOf[SSubsetStructValue].prev) -} diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SBaseStruct.scala b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SBaseStruct.scala index 11061b8dfdd..51353e9fb67 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SBaseStruct.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SBaseStruct.scala @@ -70,7 +70,7 @@ trait SBaseStructValue extends SValue { loadField(cb, st.fieldIdx(fieldName)) def subset(fieldNames: String*): SBaseStructValue = - new SSubsetStructValue(new SSubsetStruct(st, fieldNames.toIndexedSeq), this) + new SStructViewValue(SStructView.subset(fieldNames.toIndexedSeq, st), this) override def hash(cb: EmitCodeBuilder): SInt32Value = { val hash_result = cb.newLocal[Int]("hash_result_struct", 1) diff --git a/hail/src/main/scala/is/hail/types/virtual/TBaseStruct.scala b/hail/src/main/scala/is/hail/types/virtual/TBaseStruct.scala index c47efe10025..15b1fe05b6e 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TBaseStruct.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TBaseStruct.scala @@ -32,7 +32,8 @@ abstract class TBaseStruct extends Type { def fields: IndexedSeq[Field] - lazy val fieldIdx: collection.Map[String, Int] = toMapFast(fields)(_.name, _.index) + lazy val fieldIdx: collection.Map[String, Int] = + toMapFast(fields)(_.name, _.index) override def children: IndexedSeq[Type] = types diff --git a/hail/src/test/scala/is/hail/types/physical/stypes/concrete/SStructViewSuite.scala b/hail/src/test/scala/is/hail/types/physical/stypes/concrete/SStructViewSuite.scala new file mode 100644 index 00000000000..2d47415b672 --- /dev/null +++ b/hail/src/test/scala/is/hail/types/physical/stypes/concrete/SStructViewSuite.scala @@ -0,0 +1,60 @@ +package is.hail.types.physical.stypes.concrete + +import is.hail.HailSuite +import is.hail.types.physical.stypes.SType +import is.hail.types.physical.stypes.interfaces.SBaseStruct +import is.hail.types.virtual.{TInt32, TInt64, TStruct} +import is.hail.utils.FastSeq + +import org.testng.annotations.Test + +class SStructViewSuite extends HailSuite { + + val xyz: SBaseStruct = + SType.canonical( + TStruct( + "x" -> TInt32, + "y" -> TInt64, + "z" -> TStruct("a" -> TInt32), + ) + ).asInstanceOf[SBaseStruct] + + @Test def testCastRename(): Unit = { + val newtype = TStruct("x" -> TStruct("b" -> TInt32)) + + val expected = + SStructView( + parent = xyz, + restrict = xyz.virtualType.asInstanceOf[TStruct].typeAfterSelectNames(FastSeq("z")), + rename = newtype, + ) + + assert(SStructView.subset(FastSeq("z"), xyz).castRename(newtype) == expected) + } + + @Test def testSubsetRenameSubset(): Unit = { + val subset = + SStructView.subset( + FastSeq("x"), + SStructView.subset(FastSeq("x", "z"), xyz) + .castRename(TStruct("y" -> TInt32, "x" -> TStruct("b" -> TInt32))) + .asInstanceOf[SBaseStruct], + ) + + val expected = + SStructView( + parent = xyz, + restrict = xyz.virtualType.asInstanceOf[TStruct].typeAfterSelectNames(FastSeq("z")), + rename = TStruct("x" -> TStruct("b" -> TInt32)), + ) + + assert(subset == expected) + } + + @Test def testAssertIsomorphism(): Unit = + intercept[AssertionError] { + SStructView.subset(FastSeq("x", "y"), xyz) + .castRename(TStruct("x" -> TInt64, "x" -> TInt32)) + } + +} diff --git a/hail/src/test/scala/is/hail/types/physical/stypes/concrete/SSubsetStructSuite.scala b/hail/src/test/scala/is/hail/types/physical/stypes/concrete/SSubsetStructSuite.scala deleted file mode 100644 index 1e80afcbafa..00000000000 --- a/hail/src/test/scala/is/hail/types/physical/stypes/concrete/SSubsetStructSuite.scala +++ /dev/null @@ -1,43 +0,0 @@ -package is.hail.types.physical.stypes.concrete - -import is.hail.HailSuite -import is.hail.types.physical.stypes.SType -import is.hail.types.physical.stypes.interfaces.SBaseStruct -import is.hail.types.virtual.{TInt32, TInt64, TStruct} -import is.hail.utils.FastSeq - -import org.testng.annotations.Test - -class SSubsetStructSuite extends HailSuite { - - @Test def testCastRename(): Unit = { - val sparent = SType.canonical( - TStruct( - "x" -> TInt32, - "y" -> TInt64, - "z" -> TStruct("a" -> TInt32), - ) - ) - - val subset = - new SSubsetStruct(sparent.asInstanceOf[SBaseStruct], FastSeq("z")) - .castRename(TStruct("x" -> TStruct("b" -> TInt32))) - .asInstanceOf[SSubsetStruct] - - val expected = - new SSubsetStruct( - SType.canonical( - TStruct( - "z" -> TInt32, - "y" -> TInt64, - "x" -> TStruct("b" -> TInt32), - ) - ).asInstanceOf[SBaseStruct], - FastSeq("x"), - ) - - assert(subset == expected) - assert(subset.virtualType == expected.virtualType) - } - -}