Skip to content

Commit

Permalink
[query] blanczos_pca don't do extra loading work (#10201)
Browse files Browse the repository at this point in the history
* Use the checkpointed table from mt_to_table_of_ndarray to avoid recomputing mt

* Keep extra row fields from being included
  • Loading branch information
johnc1231 authored Mar 18, 2021
1 parent 022d02f commit 466f3c3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
9 changes: 6 additions & 3 deletions hail/python/hail/experimental/table_ndarray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from hail.utils.java import Env


def mt_to_table_of_ndarray(entry_expr, block_size=16):
def mt_to_table_of_ndarray(entry_expr, block_size=16, return_checkpointed_table_also=False):
check_entry_indexed('mt_to_table_of_ndarray/entry_expr', entry_expr)
mt = matrix_table_source('mt_to_table_of_ndarray/entry_expr', entry_expr)

Expand Down Expand Up @@ -35,8 +35,11 @@ def get_even_partitioning(ht, partition_size, total_num_rows):
ht = ht.checkpoint(temp_file_name)
num_rows = ht.count()
new_partitioning = get_even_partitioning(ht, block_size, num_rows)
ht = hl.read_table(temp_file_name, _intervals=new_partitioning)
new_part_ht = hl.read_table(temp_file_name, _intervals=new_partitioning)

grouped = ht._group_within_partitions("groups", block_size)
grouped = new_part_ht._group_within_partitions("groups", block_size)
A = grouped.select(ndarray=hl.nd.array(grouped.groups.map(lambda group: group.xs)))

if return_checkpointed_table_also:
return A, ht
return A
12 changes: 7 additions & 5 deletions hail/python/hail/methods/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,10 @@ def _blanczos_pca(entry_expr, k=10, compute_loadings=False, q_iterations=2, over
(:obj:`list` of :obj:`float`, :class:`.Table`, :class:`.Table`)
List of eigenvalues, table with column scores, table with row loadings.
"""

check_entry_indexed('mt_to_table_of_ndarray/entry_expr', entry_expr)
mt = matrix_table_source('pca/entry_expr', entry_expr)

A = mt_to_table_of_ndarray(entry_expr, block_size)
A, ht = mt_to_table_of_ndarray(entry_expr, block_size, return_checkpointed_table_also=True)
A = A.persist()

# Set Parameters
Expand Down Expand Up @@ -365,10 +365,12 @@ def hailBlanczos(A, G, k, q):
cols_and_scores = hl.zip(A.index_globals().cols, hail_array_scores).map(lambda tup: tup[0].annotate(scores=tup[1]))
st = hl.Table.parallelize(cols_and_scores, key=list(mt.col_key))

lt = mt.rows().select()
lt = ht.select()
lt = lt.annotate_globals(U=U)
lt = lt.add_index()
lt = lt.annotate(loadings=lt.U[lt.idx, :]._data_array()).select_globals()
idx_name = '_tmp_pca_loading_index'
lt = lt.add_index(idx_name)
lt = lt.annotate(loadings=lt.U[lt[idx_name], :]._data_array()).select_globals()
lt = lt.drop(lt[idx_name])

if compute_loadings:
return eigens, st, lt
Expand Down

0 comments on commit 466f3c3

Please sign in to comment.