From 466f3c32c7fed972dc9bf45834669fdff038224c Mon Sep 17 00:00:00 2001 From: John Compitello Date: Thu, 18 Mar 2021 12:23:34 -0400 Subject: [PATCH] [query] blanczos_pca dont do extra loading work (#10201) * Use the checkpointed table from mt_to_table_of_ndarray to avoid recomputing mt * Keep extra row fields from being included --- hail/python/hail/experimental/table_ndarray_utils.py | 9 ++++++--- hail/python/hail/methods/pca.py | 12 +++++++----- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/hail/python/hail/experimental/table_ndarray_utils.py b/hail/python/hail/experimental/table_ndarray_utils.py index e107db97520..7d9fcef5cef 100644 --- a/hail/python/hail/experimental/table_ndarray_utils.py +++ b/hail/python/hail/experimental/table_ndarray_utils.py @@ -3,7 +3,7 @@ from hail.utils.java import Env -def mt_to_table_of_ndarray(entry_expr, block_size=16): +def mt_to_table_of_ndarray(entry_expr, block_size=16, return_checkpointed_table_also=False): check_entry_indexed('mt_to_table_of_ndarray/entry_expr', entry_expr) mt = matrix_table_source('mt_to_table_of_ndarray/entry_expr', entry_expr) @@ -35,8 +35,11 @@ def get_even_partitioning(ht, partition_size, total_num_rows): ht = ht.checkpoint(temp_file_name) num_rows = ht.count() new_partitioning = get_even_partitioning(ht, block_size, num_rows) - ht = hl.read_table(temp_file_name, _intervals=new_partitioning) + new_part_ht = hl.read_table(temp_file_name, _intervals=new_partitioning) - grouped = ht._group_within_partitions("groups", block_size) + grouped = new_part_ht._group_within_partitions("groups", block_size) A = grouped.select(ndarray=hl.nd.array(grouped.groups.map(lambda group: group.xs))) + + if return_checkpointed_table_also: + return A, ht return A diff --git a/hail/python/hail/methods/pca.py b/hail/python/hail/methods/pca.py index 53583baf123..db8f42769f4 100644 --- a/hail/python/hail/methods/pca.py +++ b/hail/python/hail/methods/pca.py @@ -300,10 +300,10 @@ def _blanczos_pca(entry_expr, k=10, compute_loadings=False, q_iterations=2, over (:obj:`list` of :obj:`float`, :class:`.Table`, :class:`.Table`) List of eigenvalues, table with column scores, table with row loadings. """ - + check_entry_indexed('mt_to_table_of_ndarray/entry_expr', entry_expr) mt = matrix_table_source('pca/entry_expr', entry_expr) - A = mt_to_table_of_ndarray(entry_expr, block_size) + A, ht = mt_to_table_of_ndarray(entry_expr, block_size, return_checkpointed_table_also=True) A = A.persist() # Set Parameters @@ -365,10 +365,12 @@ def hailBlanczos(A, G, k, q): cols_and_scores = hl.zip(A.index_globals().cols, hail_array_scores).map(lambda tup: tup[0].annotate(scores=tup[1])) st = hl.Table.parallelize(cols_and_scores, key=list(mt.col_key)) - lt = mt.rows().select() + lt = ht.select() lt = lt.annotate_globals(U=U) - lt = lt.add_index() - lt = lt.annotate(loadings=lt.U[lt.idx, :]._data_array()).select_globals() + idx_name = '_tmp_pca_loading_index' + lt = lt.add_index(idx_name) + lt = lt.annotate(loadings=lt.U[lt[idx_name], :]._data_array()).select_globals() + lt = lt.drop(lt[idx_name]) if compute_loadings: return eigens, st, lt