persist = checkpoint

patrick-schultz committed Oct 17, 2022
1 parent a180df4 commit 49ba543

Showing 11 changed files with 65 additions and 101 deletions.
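The commit reimplements persist as checkpoint: `Table.persist` and `MatrixTable.persist` now write the dataset to a temporary file and read it back, with cleanup deferred to `unpersist`, instead of delegating Spark storage levels to the backend. The public `storage_level` argument is still accepted but no longer forwarded. A user-level sketch of the new behavior (illustrative only, not part of the diff):

    import hail as hl

    hl.init()

    t = hl.utils.range_table(100)
    pt = t.persist()   # under the hood: t.checkpoint(<temporary file>)
    pt.count()         # reads from the checkpointed copy
    t.unpersist()      # deletes the temporary file recorded for `t`;
                       # raises ValueError if `t` was never persisted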
29 changes: 22 additions & 7 deletions hail/python/hail/backend/backend.py
@@ -21,6 +21,10 @@ def fatal_error_from_java_error_triplet(short_message, expanded_message, error_i
 
 
 class Backend(abc.ABC):
+    @abc.abstractmethod
+    def __init__(self):
+        self._persisted_locations = dict()
+
     @abc.abstractmethod
     def stop(self):
         pass
@@ -127,18 +131,29 @@ def index_bgen(self,
     def import_fam(self, path: str, quant_pheno: bool, delimiter: str, missing: str):
         pass
 
-    def persist_table(self, t, storage_level):
-        # FIXME: this can't possibly be right.
-        return t
+    def persist_table(self, t):
+        from hail.context import TemporaryFilename
+        tf = TemporaryFilename(prefix='persist_table')
+        self._persisted_locations[t] = tf
+        return t.checkpoint(tf.__enter__())
 
     def unpersist_table(self, t):
-        return t
+        try:
+            self._persisted_locations[t].__exit__(None, None, None)
+        except KeyError as err:
+            raise ValueError(f'{t} is not persisted') from err
 
-    def persist_matrix_table(self, mt, storage_level):
-        return mt
+    def persist_matrix_table(self, mt):
+        from hail.context import TemporaryFilename
+        tf = TemporaryFilename(prefix='persist_matrix_table')
+        self._persisted_locations[mt] = tf
+        return mt.checkpoint(tf.__enter__())
 
     def unpersist_matrix_table(self, mt):
-        return mt
+        try:
+            self._persisted_locations[mt].__exit__(None, None, None)
+        except KeyError as err:
+            raise ValueError(f'{mt} is not persisted') from err
 
     def unpersist_block_matrix(self, id):
         pass
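For reference, the base-class implementation above drives `TemporaryFilename` as a manually managed context manager: `__enter__` yields a path at persist time and `__exit__` deletes the file at unpersist time. A standalone sketch of that lifecycle, assuming only the names visible in the diff:

    import hail as hl
    from hail.context import TemporaryFilename

    hl.init()
    t = hl.utils.range_table(10)

    tf = TemporaryFilename(prefix='persist_table')
    t = t.checkpoint(tf.__enter__())   # write to the temporary path, read it back
    # ... use `t` while the checkpoint is live ...
    tf.__exit__(None, None, None)      # delete the temporary file (what unpersist does)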
1 change: 1 addition & 0 deletions hail/python/hail/backend/py4j_backend.py
@@ -42,6 +42,7 @@ class Py4JBackend(Backend):
 
     @abc.abstractmethod
     def __init__(self):
+        super(Py4JBackend, self).__init__()
         import base64
 
         def decode_bytearray(encoded):
26 changes: 2 additions & 24 deletions hail/python/hail/backend/service_backend.py
@@ -10,7 +10,7 @@
 import yaml
 from pathlib import Path
 
-from hail.context import TemporaryDirectory, tmp_dir, TemporaryFilename, revision, _TemporaryFilenameManager
+from hail.context import TemporaryDirectory, tmp_dir, TemporaryFilename, revision
 from hail.utils import FatalError
 from hail.expr.types import HailType, dtype, ttuple, tvoid
 from hail.expr.table_type import ttable
@@ -273,6 +273,7 @@ def __init__(self,
                  worker_cores: Optional[Union[int, str]],
                  worker_memory: Optional[str],
                  name_prefix: str):
+        super(ServiceBackend, self).__init__()
         self.billing_project = billing_project
         self._sync_fs = sync_fs
         self._async_fs = async_fs
@@ -290,7 +291,6 @@ def __init__(self,
         self.worker_cores = worker_cores
         self.worker_memory = worker_memory
         self.name_prefix = name_prefix
-        self._persisted_locations: Dict[Any, _TemporaryFilenameManager] = dict()
 
     def debug_info(self) -> Dict[str, Any]:
         return {
@@ -701,28 +701,6 @@ def persist_expression(self, expr):
         write_expression(expr, fname)
         return read_expression(fname, _assert_type=expr.dtype)
 
-    def persist_table(self, t, storage_level):
-        tf = TemporaryFilename(prefix='persist_table')
-        self._persisted_locations[t] = tf
-        return t.checkpoint(tf.__enter__())
-
-    def unpersist_table(self, t):
-        try:
-            self._persisted_locations[t].__exit__(None, None, None)
-        except KeyError as err:
-            raise ValueError(f'{t} is not persisted') from err
-
-    def persist_matrix_table(self, mt, storage_level):
-        tf = TemporaryFilename(prefix='persist_matrix_table')
-        self._persisted_locations[mt] = tf
-        return mt.checkpoint(tf.__enter__())
-
-    def unpersist_matrix_table(self, mt):
-        try:
-            self._persisted_locations[mt].__exit__(None, None, None)
-        except KeyError as err:
-            raise ValueError(f'{mt} is not persisted') from err
-
     def set_flags(self, **flags: str):
         self.flags.update(flags)
 
13 changes: 0 additions & 13 deletions hail/python/hail/backend/spark_backend.py
@@ -300,19 +300,6 @@ def matrix_type(self, mir):
         jir = self._to_java_matrix_ir(mir)
         return tmatrix._from_java(jir.typ())
 
-    def persist_table(self, t, storage_level):
-        return Table._from_java(self._jbackend.pyPersistTable(storage_level, self._to_java_table_ir(t._tir)))
-
-    def unpersist_table(self, t):
-        return Table._from_java(self._to_java_table_ir(t._tir).pyUnpersist())
-
-    def persist_matrix_table(self, mt, storage_level):
-        ir = mt._mir.handle_randomness(None, None)
-        return MatrixTable._from_java(self._jbackend.pyPersistMatrix(storage_level, self._to_java_matrix_ir(ir)))
-
-    def unpersist_matrix_table(self, mt):
-        return MatrixTable._from_java(self._to_java_matrix_ir(mt._mir).pyUnpersist())
-
     def unpersist_block_matrix(self, id):
         self._jhc.backend().unpersist(id)
 
4 changes: 2 additions & 2 deletions hail/python/hail/matrixtable.py
@@ -2095,7 +2095,7 @@ def aggregate_entries(self, expr, _localize=True):
         analyze('MatrixTable.aggregate_entries', expr, self._global_indices, {self._row_axis, self._col_axis})
         agg_ir = ir.MatrixAggregate(base._mir, expr._ir)
         if _localize:
-            return Env.backend().execute(agg_ir)
+            return Env.backend().execute(ir.MakeTuple([agg_ir]))[0]
         else:
             return construct_expr(ir.LiftMeOut(agg_ir), expr.dtype)
 
@@ -3458,7 +3458,7 @@ def persist(self, storage_level: str = 'MEMORY_AND_DISK') -> 'MatrixTable':
         :class:`.MatrixTable`
             Persisted dataset.
         """
-        return Env.backend().persist_matrix_table(self, storage_level)
+        return Env.backend().persist_matrix_table(self)
 
     def unpersist(self) -> 'MatrixTable':
         """
2 changes: 1 addition & 1 deletion hail/python/hail/table.py
@@ -2065,7 +2065,7 @@ def persist(self, storage_level='MEMORY_AND_DISK') -> 'Table':
         :class:`.Table`
             Persisted table.
         """
-        return Env.backend().persist_table(self, storage_level)
+        return Env.backend().persist_table(self)
 
     def unpersist(self) -> 'Table':
         """
1 change: 0 additions & 1 deletion hail/python/test/hail/methods/test_qc.py
@@ -215,7 +215,6 @@ def make_mt(rows):
                 n_discordant=0),
         ]
 
-    @fails_service_backend()
     def test_concordance_no_values_doesnt_error(self):
         dataset = get_dataset().filter_rows(False)
         _, cols_conc, rows_conc = hl.concordance(dataset, dataset)
1 change: 1 addition & 0 deletions hail/python/test/hail/methods/test_statgen.py
@@ -1822,6 +1822,7 @@ def test_warn_if_no_intercept(self):
         self.assertFalse(hl.methods.statgen._warn_if_no_intercept('', [intercept] + covariates))
 
     @fails_service_backend()
+    @fails_local_backend()
     def test_regression_field_dependence(self):
         mt = hl.utils.range_matrix_table(10, 10)
         mt = mt.annotate_cols(c1 = hl.literal([x % 2 == 0 for x in range(10)])[mt.col_idx], c2 = hl.rand_norm(0, 1))
7 changes: 7 additions & 0 deletions hail/python/test/hail/table/test_table.py
@@ -1784,6 +1784,13 @@ def test_read_partitions_with_missing_key():
     assert hl.read_table(path, _n_partitions=10).n_partitions() == 1  # one key => one partition
 
 
+def test_empty_tree_aggregate():
+    ht = hl.utils.range_table(100, 3)
+    path = new_temp_file()
+    ht = ht.checkpoint(path).filter(False)
+    assert ht.aggregate(hl.agg.counter(ht.idx)) == {}
+
+
 def test_interval_filter_partitions():
     ht = hl.utils.range_table(100, 3)
     path = new_temp_file()
32 changes: 0 additions & 32 deletions hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
@@ -543,38 +543,6 @@ class SparkBackend(
     }
   }
 
-  def pyPersistMatrix(storageLevel: String, mir: MatrixIR): MatrixIR = {
-    ExecutionTimer.logTime("SparkBackend.pyPersistMatrix") { timer =>
-      val level = try {
-        StorageLevel.fromString(storageLevel)
-      } catch {
-        case e: IllegalArgumentException =>
-          fatal(s"unknown StorageLevel: $storageLevel")
-      }
-
-      withExecuteContext(timer, selfContainedExecution = false) { ctx =>
-        val tv = Interpret(mir, ctx, optimize = true)
-        MatrixLiteral(mir.typ, TableLiteral(tv.persist(ctx, level), ctx.theHailClassLoader))
-      }
-    }
-  }
-
-  def pyPersistTable(storageLevel: String, tir: TableIR): TableIR = {
-    ExecutionTimer.logTime("SparkBackend.pyPersistTable") { timer =>
-      val level = try {
-        StorageLevel.fromString(storageLevel)
-      } catch {
-        case e: IllegalArgumentException =>
-          fatal(s"unknown StorageLevel: $storageLevel")
-      }
-
-      withExecuteContext(timer, selfContainedExecution = false) { ctx =>
-        val tv = Interpret(tir, ctx, optimize = true)
-        TableLiteral(tv.persist(ctx, level), ctx.theHailClassLoader)
-      }
-    }
-  }
-
   def pyToDF(tir: TableIR): DataFrame = {
     ExecutionTimer.logTime("SparkBackend.pyToDF") { timer =>
       withExecuteContext(timer, selfContainedExecution = false) { ctx =>
50 changes: 29 additions & 21 deletions hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala
@@ -624,16 +624,19 @@
 
       val distAggStatesRef = Ref(genUID(), TArray(TString))
 
-
-      def combineGroup(partArrayRef: IR): IR = {
+      def combineGroup(partArrayRef: IR, useInitStates: Boolean): IR = {
         Begin(FastIndexedSeq(
-          bindIR(ReadValue(ArrayRef(partArrayRef, 0), codecSpec, codecSpec.encodedVirtualType)) { serializedTuple =>
-            Begin(
-              aggs.aggs.zipWithIndex.map { case (sig, i) =>
-                InitFromSerializedValue(i, GetTupleElement(serializedTuple, i), sig.state)
-              })
+          if (useInitStates) {
+            initFromSerializedStates
+          } else {
+            bindIR(ReadValue(ArrayRef(partArrayRef, 0), codecSpec, codecSpec.encodedVirtualType)) { serializedTuple =>
+              Begin(
+                aggs.aggs.zipWithIndex.map { case (sig, i) =>
+                  InitFromSerializedValue(i, GetTupleElement(serializedTuple, i), sig.state)
+                })
+            }
           },
-          forIR(StreamRange(1, ArrayLen(partArrayRef), 1, requiresMemoryManagementPerElement = true)) { fileIdx =>
+          forIR(StreamRange(if (useInitStates) 0 else 1, ArrayLen(partArrayRef), 1, requiresMemoryManagementPerElement = true)) { fileIdx =>
 
             bindIR(ReadValue(ArrayRef(partArrayRef, fileIdx), codecSpec, codecSpec.encodedVirtualType)) { serializedTuple =>
               Begin(
@@ -648,22 +651,27 @@
          FastIndexedSeq[(String, IR)](currentAggStates.name -> collected, iterNumber.name -> I32(0)),
          If(ArrayLen(currentAggStates) <= I32(branchFactor),
            currentAggStates,
-            Recur(treeAggFunction, FastIndexedSeq(CollectDistributedArray(mapIR(StreamGrouped(ToStream(currentAggStates), I32(branchFactor)))(x => ToArray(x)),
-              MakeStruct(FastSeq()), distAggStatesRef.name, genUID(),
-              RunAgg(
-                combineGroup(distAggStatesRef),
-                WriteValue(MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), Str(tmpDir) + UUID4(), codecSpec),
-                aggs.states
-              ), strConcat(Str("iteration="), invoke("str", TString, iterNumber), Str(", n_states="), invoke("str", TString, ArrayLen(currentAggStates))),
-              "table_tree_aggregate"
-            ), iterNumber + 1), currentAggStates.typ)))
+            Recur(treeAggFunction,
+              FastIndexedSeq(
+                CollectDistributedArray(
+                  mapIR(StreamGrouped(ToStream(currentAggStates), I32(branchFactor)))(x => ToArray(x)),
+                  MakeStruct(FastSeq()),
+                  distAggStatesRef.name,
+                  genUID(),
+                  RunAgg(
+                    combineGroup(distAggStatesRef, false),
+                    WriteValue(MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), Str(tmpDir) + UUID4(), codecSpec),
+                    aggs.states
+                  ),
+                  strConcat(Str("iteration="), invoke("str", TString, iterNumber), Str(", n_states="), invoke("str", TString, ArrayLen(currentAggStates))),
+                  "table_tree_aggregate"),
+                iterNumber + 1),
+              currentAggStates.typ)))
        ) { finalParts =>
          RunAgg(
-            combineGroup(finalParts),
+            combineGroup(finalParts, true),
            Let("global", globals,
-              Let(
-                resultUID,
-                results,
+              Let(resultUID, results,
                aggs.postAggIR)),
            aggs.states
          )
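The `useInitStates` flag threaded through `combineGroup` distinguishes the two combine sites: intermediate tree levels seed their aggregation state from the first serialized leaf and fold in the rest (the loop starts at index 1), while the final combine starts from the initial states and folds in every element (the loop starts at 0), so aggregating zero partitions still yields the aggregator's zero value, which is the case `test_empty_tree_aggregate` above appears to exercise. A Python sketch of that shape (illustrative names only, not Hail APIs):

    def combine_group(group, init, combine, use_init_states):
        if use_init_states:
            acc, rest = init(), group        # fresh state; fold in everything
        else:
            acc, rest = group[0], group[1:]  # seed from the first leaf state
        for state in rest:
            acc = combine(acc, state)
        return acc

    def tree_aggregate(leaves, branch_factor, init, combine):
        states = leaves
        while len(states) > branch_factor:
            # combine fixed-size groups of states, one tree level at a time
            states = [combine_group(states[i:i + branch_factor], init, combine,
                                    use_init_states=False)
                      for i in range(0, len(states), branch_factor)]
        # final pass: starting from init() makes zero leaves well defined
        return combine_group(states, init, combine, use_init_states=True)

    assert tree_aggregate([], branch_factor=8, init=int, combine=lambda a, b: a + b) == 0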