Skip to content

Commit

Permalink
Generalise SSubsetStruct to views on SBaseStructs
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Feb 6, 2024
1 parent 65bf923 commit 2dd1fa3
Show file tree
Hide file tree
Showing 11 changed files with 231 additions and 196 deletions.
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1122,7 +1122,7 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
iec.map(cb)(pc => cast(cb, pc))
case CastRename(v, _typ) =>
emitI(v)
.map(cb)(pc => pc.st.castRename(_typ).fromValues(pc.valueTuple))
.map(cb)(_.castRename(_typ))
case NA(typ) =>
IEmitCode.missing(cb, SUnreachable.fromVirtualType(typ).defaultValue)
case IsNA(v) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ abstract class PBaseStruct extends PType {

def size: Int = fields.length

def isIsomorphicTo(other: PBaseStruct) =
def isIsomorphicTo(other: PBaseStruct): Boolean =
this.fields.size == other.fields.size && this.isCompatibleWith(other)

def _toPretty: String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import is.hail.asm4s.{Code, Value}
import is.hail.backend.HailStateManager
import is.hail.expr.ir.EmitCodeBuilder
import is.hail.types.physical.stypes.SValue
import is.hail.types.physical.stypes.concrete.SSubsetStruct
import is.hail.types.physical.stypes.interfaces.SBaseStructValue
import is.hail.types.physical.stypes.concrete.SStructView
import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructValue}
import is.hail.types.virtual.TStruct
import is.hail.utils._

Expand Down Expand Up @@ -137,8 +137,8 @@ final case class PSubsetStruct(ps: PStruct, _fieldNames: IndexedSeq[String]) ext
): Long =
throw new UnsupportedOperationException

def sType: SSubsetStruct =
new SSubsetStruct(ps.sType, _fieldNames)
def sType: SBaseStruct =
SStructView.subset(_fieldNames, ps.sType)

def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean)
: Value[Long] =
Expand Down
4 changes: 4 additions & 0 deletions hail/src/main/scala/is/hail/types/physical/stypes/SCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import is.hail.expr.ir.EmitCodeBuilder
import is.hail.types.physical.stypes.concrete.SRNGStateValue
import is.hail.types.physical.stypes.interfaces._
import is.hail.types.physical.stypes.primitives._
import is.hail.types.virtual.Type

object SCode {
def add(cb: EmitCodeBuilder, left: SValue, right: SValue, required: Boolean): SValue = {
Expand Down Expand Up @@ -114,6 +115,9 @@ trait SValue {
throw new UnsupportedOperationException(s"Stype $st has no hashcode")

def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value

def castRename(t: Type): SValue =
st.castRename(t).fromValues(valueTuple)
}

trait SSettable extends SValue {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class SStackStructValue(val st: SStackStruct, val values: IndexedSeq[EmitValue])
override def isFieldMissing(cb: EmitCodeBuilder, fieldIdx: Int): Value[Boolean] =
values(fieldIdx).m

override def subset(fieldNames: String*): SStackStructValue = {
override def subset(fieldNames: String*): SBaseStructValue = {
val newToOld = fieldNames.map(st.fieldIdx).toArray
val oldVType = st.virtualType.asInstanceOf[TStruct]
val newVirtualType = TStruct(newToOld.map(i => (oldVType.fieldNames(i), oldVType.types(i))): _*)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package is.hail.types.physical.stypes.concrete

import is.hail.annotations.Region
import is.hail.asm4s.{Settable, TypeInfo, Value}
import is.hail.expr.ir.{EmitCodeBuilder, IEmitCode}
import is.hail.types.physical.{PCanonicalStruct, PType}
import is.hail.types.physical.stypes.{EmitType, SType, SValue}
import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructSettable, SBaseStructValue}
import is.hail.types.virtual.{TBaseStruct, TStruct, Type}

object SStructView {
def subset(fieldnames: IndexedSeq[String], struct: SBaseStruct): SStructView =
struct match {
case s @ SStructView(parent, restrict, rename) =>
SStructView(
parent,
restrict.typeAfterSelect(fieldnames.map(s.fieldIdx)),
rename.typeAfterSelectNames(fieldnames),
)

case s =>
val restrict = s.virtualType.asInstanceOf[TStruct].typeAfterSelectNames(fieldnames)
SStructView(s, restrict, restrict)
}
}

// A 'view' on `SBaseStruct`s, ie one that presents an upcast and/or renamed facade on another
final case class SStructView(parent: SBaseStruct, restrict: TStruct, rename: TStruct)
extends SBaseStruct {

assert(
restrict.canCastTo(rename),
f"""Upcast operations are not isomorphic
| restrict: '${restrict._toPretty}'
| rename: '${rename._toPretty}'
|""".stripMargin,
)

override val size: Int =
restrict.fields.length

lazy val newToOldFieldMapping: Map[Int, Int] =
restrict.fields.view.map(f => f.index -> parent.fieldIdx(f.name)).toMap

override lazy val fieldTypes: IndexedSeq[SType] =
Array.tabulate(size) { i =>
parent
.fieldTypes(newToOldFieldMapping(i))
.castRename(rename.fields(i).typ)
}

override lazy val fieldEmitTypes: IndexedSeq[EmitType] =
Array.tabulate(size) { i =>
parent
.fieldEmitTypes(newToOldFieldMapping(i))
.copy(st = fieldTypes(i))
}

override def virtualType: TBaseStruct =
rename

override def fieldIdx(fieldName: String): Int =
rename.fieldIdx(fieldName)

override def castRename(t: Type): SType =
SStructView(parent, restrict, rename = t.asInstanceOf[TStruct])

override def _coerceOrCopy(
cb: EmitCodeBuilder,
region: Value[Region],
value: SValue,
deepCopy: Boolean,
): SValue = {
if (deepCopy)
throw new NotImplementedError("Deep copy on struct view")

value.st match {
case s: SStructView if this == s && !deepCopy =>
value
}
}

override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] =
parent.settableTupleTypes()

override def fromSettables(settables: IndexedSeq[Settable[_]]): SStructViewSettable =
new SStructViewSettable(
this,
parent.fromSettables(settables).asInstanceOf[SBaseStructSettable],
)

override def fromValues(values: IndexedSeq[Value[_]]): SStructViewValue =
new SStructViewValue(this, parent.fromValues(values).asInstanceOf[SBaseStructValue])

override def copiedType: SType =
if (virtualType.size < 64)
SStackStruct(virtualType, fieldEmitTypes.map(_.copiedType))
else {
val ct = SBaseStructPointer(storageType().asInstanceOf[PCanonicalStruct])
assert(ct.virtualType == virtualType, s"ct=$ct, this=$this")
ct
}

def storageType(): PType = {
val pt = PCanonicalStruct(
required = false,
args = rename.fieldNames.zip(fieldEmitTypes.map(_.copiedType.storageType)): _*,
)
assert(pt.virtualType == virtualType, s"pt=$pt, this=$this")
pt
}

// aspirational implementation
// def storageType(): PType = StoredSTypePType(this, false)

override def containsPointers: Boolean =
parent.containsPointers

override def equals(obj: Any): Boolean =
obj match {
case s: SStructView =>
rename == s.rename &&
newToOldFieldMapping == s.newToOldFieldMapping &&
parent == s.parent // todo test isIsomorphicTo
case _ =>
false
}
}

class SStructViewValue(val st: SStructView, val prev: SBaseStructValue) extends SBaseStructValue {

override lazy val valueTuple: IndexedSeq[Value[_]] =
prev.valueTuple

override def subset(fieldNames: String*): SBaseStructValue =
new SStructViewValue(SStructView.subset(fieldNames.toIndexedSeq, st), prev)

override def loadField(cb: EmitCodeBuilder, fieldIdx: Int): IEmitCode =
prev
.loadField(cb, st.newToOldFieldMapping(fieldIdx))
.map(cb)(_.castRename(st.virtualType.fields(fieldIdx).typ))

override def isFieldMissing(cb: EmitCodeBuilder, fieldIdx: Int): Value[Boolean] =
prev.isFieldMissing(cb, st.newToOldFieldMapping(fieldIdx))
}

final class SStructViewSettable(st: SStructView, prev: SBaseStructSettable)
extends SStructViewValue(st, prev) with SBaseStructSettable {
override def subset(fieldNames: String*): SBaseStructValue =
new SStructViewSettable(SStructView.subset(fieldNames.toIndexedSeq, st), prev)

override def settableTuple(): IndexedSeq[Settable[_]] =
prev.settableTuple()

override def store(cb: EmitCodeBuilder, pv: SValue): Unit =
prev.store(cb, pv.asInstanceOf[SStructViewValue].prev)
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ trait SBaseStructValue extends SValue {
loadField(cb, st.fieldIdx(fieldName))

def subset(fieldNames: String*): SBaseStructValue =
new SSubsetStructValue(new SSubsetStruct(st, fieldNames.toIndexedSeq), this)
new SStructViewValue(SStructView.subset(fieldNames.toIndexedSeq, st), this)

override def hash(cb: EmitCodeBuilder): SInt32Value = {
val hash_result = cb.newLocal[Int]("hash_result_struct", 1)
Expand Down
3 changes: 2 additions & 1 deletion hail/src/main/scala/is/hail/types/virtual/TBaseStruct.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ abstract class TBaseStruct extends Type {

def fields: IndexedSeq[Field]

lazy val fieldIdx: collection.Map[String, Int] = toMapFast(fields)(_.name, _.index)
lazy val fieldIdx: collection.Map[String, Int] =
toMapFast(fields)(_.name, _.index)

override def children: IndexedSeq[Type] = types

Expand Down
Loading

0 comments on commit 2dd1fa3

Please sign in to comment.